Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b49349ec
Commit
b49349ec
authored
Jun 08, 2022
by
rprenger
Browse files
Adding top_p decay and bound for factual sampling from Factuality Enhanced LMs
parent
15f6bb1b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
9 deletions
+46
-9
megatron/text_generation/api.py
megatron/text_generation/api.py
+18
-8
megatron/text_generation/generation.py
megatron/text_generation/generation.py
+6
-1
megatron/text_generation_server.py
megatron/text_generation_server.py
+22
-0
No files found.
megatron/text_generation/api.py
View file @
b49349ec
...
@@ -33,6 +33,8 @@ def generate_and_post_process(model,
...
@@ -33,6 +33,8 @@ def generate_and_post_process(model,
return_output_log_probs
=
False
,
return_output_log_probs
=
False
,
top_k_sampling
=
0
,
top_k_sampling
=
0
,
top_p_sampling
=
0.0
,
top_p_sampling
=
0.0
,
factual_decay
=
0.0
,
factual_bound
=
0.0
,
temperature
=
1.0
,
temperature
=
1.0
,
add_BOS
=
False
,
add_BOS
=
False
,
use_eod_token_for_early_termination
=
True
,
use_eod_token_for_early_termination
=
True
,
...
@@ -50,6 +52,8 @@ def generate_and_post_process(model,
...
@@ -50,6 +52,8 @@ def generate_and_post_process(model,
return_output_log_probs
=
return_output_log_probs
,
return_output_log_probs
=
return_output_log_probs
,
top_k_sampling
=
top_k_sampling
,
top_k_sampling
=
top_k_sampling
,
top_p_sampling
=
top_p_sampling
,
top_p_sampling
=
top_p_sampling
,
factual_decay
=
factual_decay
,
factual_bound
=
factual_bound
,
temperature
=
temperature
,
temperature
=
temperature
,
add_BOS
=
add_BOS
,
add_BOS
=
add_BOS
,
use_eod_token_for_early_termination
=
use_eod_token_for_early_termination
,
use_eod_token_for_early_termination
=
use_eod_token_for_early_termination
,
...
@@ -78,6 +82,8 @@ def generate(model,
...
@@ -78,6 +82,8 @@ def generate(model,
return_output_log_probs
=
False
,
return_output_log_probs
=
False
,
top_k_sampling
=
0
,
top_k_sampling
=
0
,
top_p_sampling
=
0.0
,
top_p_sampling
=
0.0
,
factual_decay
=
0.0
,
factual_bound
=
0.0
,
temperature
=
1.0
,
temperature
=
1.0
,
add_BOS
=
False
,
add_BOS
=
False
,
use_eod_token_for_early_termination
=
True
,
use_eod_token_for_early_termination
=
True
,
...
@@ -95,22 +101,24 @@ def generate(model,
...
@@ -95,22 +101,24 @@ def generate(model,
# Make sure input params are avaialble to all ranks.
# Make sure input params are avaialble to all ranks.
values
=
[
tokens_to_generate
,
values
=
[
tokens_to_generate
,
return_output_log_probs
,
return_output_log_probs
,
top_k_sampling
,
top_p_sampling
,
top_k_sampling
,
top_p_sampling
,
factual_decay
,
factual_bound
,
temperature
,
add_BOS
,
use_eod_token_for_early_termination
,
temperature
,
add_BOS
,
use_eod_token_for_early_termination
,
stop_on_double_eol
,
stop_on_double_eol
,
stop_on_eol
,
stop_on_eol
,
random_seed
]
random_seed
]
values_float_tensor
=
broadcast_float_list
(
1
0
,
float_list
=
values
)
values_float_tensor
=
broadcast_float_list
(
1
2
,
float_list
=
values
)
tokens_to_generate
=
int
(
values_float_tensor
[
0
].
item
())
tokens_to_generate
=
int
(
values_float_tensor
[
0
].
item
())
return_output_log_probs
=
bool
(
values_float_tensor
[
1
].
item
())
return_output_log_probs
=
bool
(
values_float_tensor
[
1
].
item
())
top_k_sampling
=
int
(
values_float_tensor
[
2
].
item
())
top_k_sampling
=
int
(
values_float_tensor
[
2
].
item
())
top_p_sampling
=
values_float_tensor
[
3
].
item
()
top_p_sampling
=
values_float_tensor
[
3
].
item
()
temperature
=
values_float_tensor
[
4
].
item
()
factual_decay
=
values_float_tensor
[
4
].
item
()
add_BOS
=
bool
(
values_float_tensor
[
5
].
item
())
factual_bound
=
values_float_tensor
[
5
].
item
()
use_eod_token_for_early_termination
=
bool
(
values_float_tensor
[
6
].
item
())
temperature
=
values_float_tensor
[
6
].
item
()
stop_on_double_eol
=
bool
(
values_float_tensor
[
7
].
item
())
add_BOS
=
bool
(
values_float_tensor
[
7
].
item
())
stop_on_eol
=
bool
(
values_float_tensor
[
8
].
item
())
use_eod_token_for_early_termination
=
bool
(
values_float_tensor
[
8
].
item
())
random_seed
=
int
(
values_float_tensor
[
9
].
item
())
stop_on_double_eol
=
bool
(
values_float_tensor
[
9
].
item
())
stop_on_eol
=
bool
(
values_float_tensor
[
10
].
item
())
random_seed
=
int
(
values_float_tensor
[
11
].
item
())
if
random_seed
!=
-
1
:
if
random_seed
!=
-
1
:
torch
.
random
.
manual_seed
(
random_seed
)
torch
.
random
.
manual_seed
(
random_seed
)
...
@@ -134,6 +142,8 @@ def generate(model,
...
@@ -134,6 +142,8 @@ def generate(model,
return_output_log_probs
=
return_output_log_probs
,
return_output_log_probs
=
return_output_log_probs
,
top_k
=
top_k_sampling
,
top_k
=
top_k_sampling
,
top_p
=
top_p_sampling
,
top_p
=
top_p_sampling
,
factual_decay
=
factual_decay
,
factual_bound
=
factual_bound
,
temperature
=
temperature
,
temperature
=
temperature
,
use_eod_token_for_early_termination
=
use_eod_token_for_early_termination
,
use_eod_token_for_early_termination
=
use_eod_token_for_early_termination
,
stop_on_double_eol
=
stop_on_double_eol
,
stop_on_double_eol
=
stop_on_double_eol
,
...
...
megatron/text_generation/generation.py
View file @
b49349ec
...
@@ -94,7 +94,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
...
@@ -94,7 +94,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
def
generate_tokens_probs_and_return_on_first_stage
(
def
generate_tokens_probs_and_return_on_first_stage
(
model
,
tokens
,
lengths
,
model
,
tokens
,
lengths
,
return_output_log_probs
=
False
,
return_output_log_probs
=
False
,
top_k
=
0
,
top_p
=
0.0
,
top_k
=
0
,
top_p
=
0.0
,
factual_decay
=
0.0
,
factual_bound
=
0.0
,
temperature
=
1.0
,
temperature
=
1.0
,
use_eod_token_for_early_termination
=
True
,
use_eod_token_for_early_termination
=
True
,
stop_on_double_eol
=
False
,
stop_on_double_eol
=
False
,
...
@@ -200,6 +200,11 @@ def generate_tokens_probs_and_return_on_first_stage(
...
@@ -200,6 +200,11 @@ def generate_tokens_probs_and_return_on_first_stage(
top_p
=
top_p
,
top_p
=
top_p
,
temperature
=
temperature
,
temperature
=
temperature
,
vocab_size
=
tokenizer
.
vocab_size
)
vocab_size
=
tokenizer
.
vocab_size
)
if
top_p
>
0.0
and
factual_decay
>
0.0
:
top_p
=
top_p
*
factual_decay
if
factual_bound
>
0.0
:
top_p
=
max
(
top_p
,
factual_bound
)
# If a prompt length is smaller or equal th current context
# If a prompt length is smaller or equal th current context
# length, it means we have started generating tokens
# length, it means we have started generating tokens
started
=
lengths
<=
context_length
started
=
lengths
<=
context_length
...
...
megatron/text_generation_server.py
View file @
b49349ec
...
@@ -93,6 +93,26 @@ class MegatronGenerate(Resource):
...
@@ -93,6 +93,26 @@ class MegatronGenerate(Resource):
if
not
(
0
<=
top_p
<=
1.0
):
if
not
(
0
<=
top_p
<=
1.0
):
return
"top_p must be less than or equal to 1.0"
return
"top_p must be less than or equal to 1.0"
factual_decay
=
0.0
if
"factual_decay"
in
request
.
get_json
():
factual_decay
=
request
.
get_json
()[
"factual_decay"
]
if
not
(
type
(
factual_decay
)
==
float
):
return
"factual_decay must be a positive float less than or equal to 1.0"
if
top_p
==
0.0
:
return
"factual_decay cannot be set without top_p"
if
not
(
0
<=
factual_decay
<=
1.0
):
return
"factual_decay must be less than or equal to 1.0"
factual_bound
=
0.0
if
"factual_bound"
in
request
.
get_json
():
factual_bound
=
request
.
get_json
()[
"factual_bound"
]
if
not
(
type
(
factual_bound
)
==
float
):
return
"factual_bound must be a positive float less than or equal to top_p"
if
top_p
==
0.0
:
return
"factual_bound cannot be set without top_p"
if
not
(
0.0
<
factual_bound
<=
top_p
):
return
"factual_bound must be greater than 0 and less than top_p"
add_BOS
=
False
add_BOS
=
False
if
"add_BOS"
in
request
.
get_json
():
if
"add_BOS"
in
request
.
get_json
():
add_BOS
=
request
.
get_json
()[
"add_BOS"
]
add_BOS
=
request
.
get_json
()[
"add_BOS"
]
...
@@ -143,6 +163,8 @@ class MegatronGenerate(Resource):
...
@@ -143,6 +163,8 @@ class MegatronGenerate(Resource):
return_output_log_probs
=
logprobs
,
return_output_log_probs
=
logprobs
,
top_k_sampling
=
top_k
,
top_k_sampling
=
top_k
,
top_p_sampling
=
top_p
,
top_p_sampling
=
top_p
,
factual_decay
=
factual_decay
,
factual_bound
=
factual_bound
,
temperature
=
temperature
,
temperature
=
temperature
,
add_BOS
=
add_BOS
,
add_BOS
=
add_BOS
,
use_eod_token_for_early_termination
=
True
,
use_eod_token_for_early_termination
=
True
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment