Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
9b131fad
Commit
9b131fad
authored
Nov 02, 2021
by
rprenger
Browse files
Adding stop token logic and random seed for deterministic answers
parent
83bc79d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
6 deletions
+23
-6
megatron/text_generation/api.py
megatron/text_generation/api.py
+13
-5
megatron/text_generation_server.py
megatron/text_generation_server.py
+10
-1
No files found.
megatron/text_generation/api.py
View file @
9b131fad
...
...
@@ -37,7 +37,8 @@ def generate_and_post_process(model,
add_BOS
=
False
,
use_eod_token_for_early_termination
=
True
,
stop_on_double_eol
=
False
,
stop_on_eol
=
False
):
stop_on_eol
=
False
,
random_seed
=-
1
):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
...
...
@@ -53,7 +54,8 @@ def generate_and_post_process(model,
add_BOS
=
add_BOS
,
use_eod_token_for_early_termination
=
use_eod_token_for_early_termination
,
stop_on_double_eol
=
stop_on_double_eol
,
stop_on_eol
=
stop_on_eol
)
stop_on_eol
=
stop_on_eol
,
random_seed
=
random_seed
)
# Only post-process on first stage.
if
mpu
.
is_pipeline_first_stage
():
...
...
@@ -80,7 +82,8 @@ def generate(model,
add_BOS
=
False
,
use_eod_token_for_early_termination
=
True
,
stop_on_double_eol
=
False
,
stop_on_eol
=
False
):
stop_on_eol
=
False
,
random_seed
=-
1
):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
...
...
@@ -95,8 +98,9 @@ def generate(model,
top_k_sampling
,
top_p_sampling
,
temperature
,
add_BOS
,
use_eod_token_for_early_termination
,
stop_on_double_eol
,
stop_on_eol
]
values_float_tensor
=
broadcast_float_list
(
9
,
float_list
=
values
)
stop_on_eol
,
random_seed
]
values_float_tensor
=
broadcast_float_list
(
10
,
float_list
=
values
)
tokens_to_generate
=
int
(
values_float_tensor
[
0
].
item
())
return_output_log_probs
=
bool
(
values_float_tensor
[
1
].
item
())
top_k_sampling
=
int
(
values_float_tensor
[
2
].
item
())
...
...
@@ -106,6 +110,10 @@ def generate(model,
use_eod_token_for_early_termination
=
bool
(
values_float_tensor
[
6
].
item
())
stop_on_double_eol
=
bool
(
values_float_tensor
[
7
].
item
())
stop_on_eol
=
bool
(
values_float_tensor
[
8
].
item
())
random_seed
=
int
(
values_float_tensor
[
9
].
item
())
if
random_seed
!=
-
1
:
torch
.
random
.
manual_seed
(
random_seed
)
# Tokenize prompts and get the batch.
# Note that these tensors are broadcasted to all ranks.
...
...
megatron/text_generation_server.py
View file @
9b131fad
...
...
@@ -114,6 +114,14 @@ class MegatronGenerate(Resource):
if
not
isinstance
(
stop_on_eol
,
bool
):
return
"stop_on_eol must be a boolean value"
random_seed
=
-
1
if
"random_seed"
in
request
.
get_json
():
random_seed
=
request
.
get_json
()[
"random_seed"
]
if
not
isinstance
(
random_seed
,
int
):
return
"random_seed must be integer"
if
random_seed
<
0
:
return
"random_seed must be a positive integer"
# if str(request.remote_addr) == "10.14.68.146":
# return "Too many tokens requested from this IP address. Contact Ryan Prenger rprenger@nvidia.com"
...
...
@@ -135,7 +143,8 @@ class MegatronGenerate(Resource):
add_BOS
=
add_BOS
,
use_eod_token_for_early_termination
=
True
,
stop_on_double_eol
=
stop_on_double_eol
,
stop_on_eol
=
stop_on_eol
)
stop_on_eol
=
stop_on_eol
,
random_seed
=
random_seed
)
except
ValueError
as
ve
:
return
"Length of prompt + tokens_to_generate longer than allowed"
print
(
"end time: "
,
datetime
.
datetime
.
now
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment