OpenDAS / Megatron-LM · Commits

Commit 554bb262
Authored Oct 19, 2021 by rprenger
Parent: a3770921

    Code that keeps it from dying when the input prompts are too long
Showing 4 changed files with 29 additions and 17 deletions (+29 −17):
  megatron/text_generation/api.py          +2  −1
  megatron/text_generation/generation.py   +4  −0
  megatron/text_generation_server.py       +19 −15
  tools/run_text_generation_server.py      +4  −1
megatron/text_generation/api.py
```diff
@@ -113,10 +113,11 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
 
     if just_score:
         return score_and_return_on_first_stage(
             model, context_tokens_tensor, context_length_tensor)
 
     # Main inference function.
     # Note that the outputs are available on the first stage.
     return generate_tokens_probs_and_return_on_first_stage(
```
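The `tokens` and `lengths` tensors that feed the new length check come from `tokenize_prompts` above. As a rough, hedged emulation (not Megatron's actual implementation), the padded token matrix is presumably sized to the longest prompt plus `tokens_to_generate`, which is how an over-long prompt can push the batch past the model's `max_position_embeddings`. `fake_tokenize_prompts` below is purely illustrative.

```python
import torch

def fake_tokenize_prompts(prompt_token_lists, tokens_to_generate, pad_id=0):
    # Illustrative stand-in for tokenize_prompts: pad every prompt out to
    # (longest prompt + tokens_to_generate) and also return per-prompt lengths.
    lengths = torch.tensor([len(p) for p in prompt_token_lists])
    total_len = lengths.max().item() + tokens_to_generate
    tokens = torch.full((len(prompt_token_lists), total_len), pad_id, dtype=torch.long)
    for i, p in enumerate(prompt_token_lists):
        tokens[i, :len(p)] = torch.tensor(p, dtype=torch.long)
    return tokens, lengths

tokens, lengths = fake_tokenize_prompts([[1] * 2040, [1] * 2000], tokens_to_generate=100)
print(tokens.shape, lengths.tolist())   # torch.Size([2, 2140]) [2040, 2000]
# generation.py (next file) compares lengths.min() against
# min(tokens.size(1), args.max_position_embeddings).
```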
megatron/text_generation/generation.py
```diff
@@ -130,6 +130,10 @@ def generate_tokens_probs_and_return_on_first_stage(
     min_prompt_length = lengths.min().item()
     max_sequence_length = tokens.size(1)
+    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
 
+    # If the context is too big, this happens
+    if min_prompt_length >= max_sequence_length:
+        raise ValueError
 
     # forward step.
     forward_step = ForwardStep(model, batch_size, max_sequence_length)
```
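The guard compares the shortest prompt in the batch against the sequence length after capping it at the model's position-embedding limit; if even the shortest prompt already fills that budget, there is no room left to generate and the request is rejected. Below is a minimal, self-contained sketch of the same arithmetic; `check_prompt_fits` is a hypothetical helper for illustration, not part of Megatron.

```python
import torch

def check_prompt_fits(lengths, tokens, max_position_embeddings):
    # Same condition the commit raises on: cap the padded sequence length at
    # the position-embedding limit and reject prompts that already fill it.
    min_prompt_length = lengths.min().item()
    max_sequence_length = min(tokens.size(1), max_position_embeddings)
    if min_prompt_length >= max_sequence_length:
        raise ValueError("prompt length + tokens_to_generate exceeds the "
                         "maximum sequence length")
    return max_sequence_length

# Example: a 2048-position model, a batch padded to 2150 tokens, and a
# shortest prompt of 2050 tokens gets rejected.
lengths = torch.tensor([2050, 2100])
tokens = torch.zeros(2, 2150, dtype=torch.long)
try:
    check_prompt_fits(lengths, tokens, max_position_embeddings=2048)
except ValueError as e:
    print("rejected:", e)
```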
megatron/text_generation_server.py
```diff
@@ -36,9 +36,6 @@ class MegatronGenerate(Resource):
     def put(self):
         args = get_args()
-        print("request IP: " + str(request.remote_addr))
-        print(json.dumps(request.get_json()), flush=True)
-        print("current time: ", datetime.datetime.now())
         if not "prompts" in request.get_json():
             return "prompts argument required", 400
```
```diff
@@ -106,19 +103,26 @@ class MegatronGenerate(Resource):
             return "add_BOS must be a boolean value"
 
         with lock:  # Need to get lock to keep multiple threads from hitting code
+            print("request IP: " + str(request.remote_addr))
+            print(json.dumps(request.get_json()), flush=True)
+            print("start time: ", datetime.datetime.now())
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _ = \
-                generate_and_post_process(
-                    self.model,
-                    prompts=prompts,
-                    tokens_to_generate=tokens_to_generate,
-                    return_output_log_probs=logprobs,
-                    top_k_sampling=top_k,
-                    top_p_sampling=top_p,
-                    temperature=temperature,
-                    add_BOS=add_BOS,
-                    use_eod_token_for_early_termination=True,
-                    just_score=just_score)
+            try:
+                response, response_seg, response_logprobs, _ = \
+                    generate_and_post_process(
+                        self.model,
+                        prompts=prompts,
+                        tokens_to_generate=tokens_to_generate,
+                        return_output_log_probs=logprobs,
+                        top_k_sampling=top_k,
+                        top_p_sampling=top_p,
+                        temperature=temperature,
+                        add_BOS=add_BOS,
+                        use_eod_token_for_early_termination=True,
+                        just_score=just_score)
+            except ValueError as ve:
+                return "Length of prompt + tokens_to_generate longer than allowed"
 
+            print("end time: ", datetime.datetime.now())
             return jsonify({"text": response,
                             "segments": response_seg,
```
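With the try/except in place, an over-long request no longer kills the worker: the `ValueError` raised in `generation.py` is caught and answered with a plain error string (and, since no status code is set, presumably HTTP 200). Below is a hedged client-side sketch of how that surfaces; the host, port, and endpoint path are assumptions, while the JSON fields match the ones `put()` reads.

```python
import requests

# Hypothetical endpoint; adjust host/port/path to however the server was launched.
URL = "http://localhost:5000/api"

resp = requests.put(URL, json={"prompts": ["a very long prompt ..."],
                               "tokens_to_generate": 4096})

if resp.text.startswith("Length of prompt"):
    # The server caught the ValueError and returned the message added in
    # this commit instead of dying.
    print("prompt too long:", resp.text)
else:
    result = resp.json()              # {"text": ..., "segments": ..., ...}
    print(result["text"])
```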
tools/run_text_generation_server.py
```diff
@@ -78,4 +78,7 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate_and_post_process(model)
+            try:
+                generate_and_post_process(model)
+            except ValueError as ve:
+                pass
```
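Worker ranks need the same guard because every rank enters `generate_and_post_process` collectively: when rank 0's request trips the length check, the other ranks raise the same `ValueError` and must swallow it to get back to waiting for the next broadcast instead of dying. A hedged sketch of that loop structure follows (not the actual script; `generate_fn` stands in for the imported Megatron call).

```python
import torch

def serve_loop(model, generate_fn):
    # Sketch of the non-zero-rank side of the server: wait for rank 0 to
    # broadcast a "do generate" signal, join the collective call, and treat
    # a too-long prompt as a no-op rather than a fatal error.
    while True:
        choice = torch.cuda.LongTensor(1)        # 0 means "run generation"
        torch.distributed.broadcast(choice, 0)   # rank 0 decides what to do
        if choice[0].item() == 0:
            try:
                generate_fn(model)               # collective across all ranks
            except ValueError:
                # Rank 0 already answered the HTTP request with an error
                # string; just return to the top of the loop.
                pass
```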