OpenDAS / Megatron-LM

Commit 9e0ee6fd, authored Oct 22, 2021 by rprenger

Hacked in way to have stop tokens

parent 148a24ad
Showing 3 changed files with 55 additions and 10 deletions
megatron/text_generation/api.py          +18 -6
megatron/text_generation/generation.py   +16 -3
megatron/text_generation_server.py       +21 -1
megatron/text_generation/api.py
@@ -35,7 +35,9 @@ def generate_and_post_process(model,
                               top_p_sampling=0.0,
                               temperature=1.0,
                               add_BOS=False,
-                              use_eod_token_for_early_termination=True):
+                              use_eod_token_for_early_termination=True,
+                              stop_on_double_eol=False,
+                              stop_on_eol=False):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""

@@ -49,7 +51,9 @@ def generate_and_post_process(model,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
         add_BOS=add_BOS,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        stop_on_double_eol=stop_on_double_eol,
+        stop_on_eol=stop_on_eol)
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():

@@ -74,7 +78,9 @@ def generate(model,
              top_p_sampling=0.0,
              temperature=1.0,
              add_BOS=False,
-             use_eod_token_for_early_termination=True):
+             use_eod_token_for_early_termination=True,
+             stop_on_double_eol=False,
+             stop_on_eol=False):
     """Given prompts and input parameters, run inference and return:
     tokens: prompts plus the generated tokens.
     lengths: length of the prompt + generations. Note that we can

@@ -87,8 +93,10 @@ def generate(model,
     values = [tokens_to_generate,
               return_output_log_probs,
               top_k_sampling, top_p_sampling,
-              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(7, float_list=values)
+              temperature, add_BOS, use_eod_token_for_early_termination,
+              stop_on_double_eol, stop_on_eol]
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())

@@ -96,6 +104,8 @@ def generate(model,
     temperature = values_float_tensor[4].item()
     add_BOS = bool(values_float_tensor[5].item())
     use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
+    stop_on_double_eol = bool(values_float_tensor[7].item())
+    stop_on_eol = bool(values_float_tensor[8].item())
     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.

@@ -117,4 +127,6 @@ def generate(model,
         top_k=top_k_sampling,
         top_p=top_p_sampling,
         temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        stop_on_double_eol=stop_on_double_eol,
+        stop_on_eol=stop_on_eol)
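
Taken together, the api.py changes thread the two new flags from the public entry points down to the sampling loop. Note how the booleans ride along in broadcast_float_list, which ships all scalar options to every pipeline rank as a single float tensor (hence the count changing from 7 to 9) before each value is re-cast with bool(...item()). A minimal usage sketch, not part of this commit: it assumes Megatron is already initialized and `model` is a loaded GPT model, and the `prompts` keyword and the return shape are assumptions rather than things shown in this diff.

# Hedged usage sketch -- assumptions flagged above, not from this commit.
from megatron.text_generation.api import generate_and_post_process

result = generate_and_post_process(
    model,
    prompts=["Q: What is a stop token?\nA:"],  # `prompts` kwarg assumed
    tokens_to_generate=32,
    stop_on_double_eol=True)  # new flag: terminate at a blank line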
megatron/text_generation/generation.py
@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs=False,
         top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True):
+        use_eod_token_for_early_termination=True,
+        stop_on_double_eol=False,
+        stop_on_eol=False):
     """Main token generation function.
     Arguments:
         model: no interleaving is supported.

@@ -231,8 +234,18 @@ def generate_tokens_probs_and_return_on_first_stage(
             # Check if all the sequences have hit the termination_id.
             done = None
             if mpu.is_pipeline_last_stage():
-                done_token = (new_sample == termination_id).byte() & \
-                    started.byte()
+                if stop_on_double_eol:
+                    hit_double_eol = (new_sample == 628).byte() & started.byte()
+                    hit_two_eols = (new_sample == 198).byte() & \
+                        (tokens[:, context_length - 1] == 198).byte() & \
+                        started.byte()
+                    done_token = hit_double_eol | hit_two_eols
+                elif stop_on_eol:
+                    hit_double_eol = (new_sample == 628).byte() & started.byte()
+                    hit_eol = (new_sample == 198).byte() & started.byte()
+                    done_token = hit_double_eol | hit_eol
+                else:
+                    done_token = (new_sample == termination_id).byte() & \
+                        started.byte()

                 just_finished = (done_token & ~is_generation_done).bool()
                 generated_sequence_lengths[just_finished.view(-1)] = \
                     context_length + 1
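
The hard-coded IDs are GPT-2 BPE vocabulary entries: 628 is the single token for "\n\n" and 198 is "\n", so hit_two_eols catches a blank line emitted as two consecutive single-newline tokens (the previous token is read back out of tokens[:, context_length - 1]). A quick sanity check of those numbers, using the Hugging Face gpt2 tokenizer as a stand-in (an assumption; the commit itself relies on Megatron's GPT2BPETokenizer, which loads the same vocabulary):

# Verify the magic numbers 628 and 198 against the GPT-2 BPE vocab.
from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
print(tok.encode("\n"))    # -> [198], one end-of-line token
print(tok.encode("\n\n"))  # -> [628], a double newline as a single token

Because the IDs are baked in, stop_on_eol and stop_on_double_eol only behave as intended with the GPT-2 vocabulary; a different tokenizer would map unrelated strings to 628 and 198.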
megatron/text_generation_server.py
@@ -98,6 +98,24 @@ class MegatronGenerate(Resource):
             add_BOS = request.get_json()["add_BOS"]
             if not isinstance(add_BOS, bool):
                 return "add_BOS must be a boolean value"

         if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
             return "Empty prompts require add_BOS=true"

+        stop_on_double_eol = False
+        if "stop_on_double_eol" in request.get_json():
+            stop_on_double_eol = request.get_json()["stop_on_double_eol"]
+            if not isinstance(stop_on_double_eol, bool):
+                return "stop_on_double_eol must be a boolean value"
+
+        stop_on_eol = False
+        if "stop_on_eol" in request.get_json():
+            stop_on_eol = request.get_json()["stop_on_eol"]
+            if not isinstance(stop_on_eol, bool):
+                return "stop_on_eol must be a boolean value"
+
+        if str(request.remote_addr) == "10.14.68.146":
+            return "Too many tokens requested from this IP address. Contact Ryan Prenger rprenger@nvidia.com"

         with lock:  # Need to get lock to keep multiple threads from hitting code
             print("request IP: " + str(request.remote_addr))

@@ -115,7 +133,9 @@ class MegatronGenerate(Resource):
                 top_p_sampling=top_p,
                 temperature=temperature,
                 add_BOS=add_BOS,
-                use_eod_token_for_early_termination=True)
+                use_eod_token_for_early_termination=True,
+                stop_on_double_eol=stop_on_double_eol,
+                stop_on_eol=stop_on_eol)
             except ValueError as ve:
                 return "Length of prompt + tokens_to_generate longer than allowed"
             print("end time: ", datetime.datetime.now())
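
On the server side the two flags arrive as optional JSON fields, are validated as booleans, and are forwarded into generate_and_post_process. A hedged client example follows; the PUT method, /api path, and port 5000 match the usual Megatron text generation server setup rather than anything shown in this diff, so check them against your deployment:

# Hedged client example -- endpoint details are assumptions.
import requests

resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Q: Name a stop token.\nA:"],
        "tokens_to_generate": 64,
        "stop_on_double_eol": True,  # new optional field from this commit
    },
    headers={"Content-Type": "application/json"})
print(resp.text)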