OpenDAS / Megatron-LM, commit e38d41ca
Authored Sep 02, 2022 by rprenger
Memory safety checks were incorrect for the tokens_to_generate=0 case
Parent: d63c2541

Showing 1 changed file with 8 additions and 3 deletions.

megatron/text_generation/generation.py (+8, -3)
@@ -47,10 +47,15 @@ def score_and_return_on_first_stage(model, tokens, lengths):
     batch_size = tokens.size(0)
     max_prompt_length = lengths.max().item()
     assert max_prompt_length == tokens.size(1)
-    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
+    if max_prompt_length > args.max_position_embeddings:
+        raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
+
+    if max_prompt_length * batch_size >= MAX_TOKENS_TO_OOM:
+        raise ValueError("Too many tokens. " + str(max_prompt_length * batch_size)
+                         + " is greater than " + str(MAX_TOKENS_TO_OOM))
 
     # forward step.
-    forward_step = ForwardStep(model, batch_size, max_sequence_length)
+    forward_step = ForwardStep(model, batch_size, max_prompt_length)
 
     # ===================
     # Pre-allocate memory
@@ -58,7 +63,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
     # Log probability of the sequence (prompt + generated tokens).
     output_log_probs = None
-    output_log_probs_size = (batch_size, max_sequence_length - 1)
+    output_log_probs_size = (batch_size, max_prompt_length - 1)
     if mpu.is_pipeline_last_stage():
         output_log_probs = torch.empty(output_log_probs_size,
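For readers following the change: the old code clamped the scoring length with min(max_prompt_length, args.max_position_embeddings), so when tokens_to_generate=0 and the prompt alone exceeded the position-embedding limit, the forward step and log-prob buffers were sized smaller than the prompt instead of the request being rejected. The new code validates the prompt length and the total token count explicitly and sizes everything by max_prompt_length. The sketch below is a minimal, self-contained reproduction of that validation logic for illustration only; the helper name check_scoring_memory and the MAX_TOKENS_TO_OOM value used here are assumptions, not identifiers or values taken from Megatron-LM.

    # Minimal sketch of the checks added in this commit. Illustrative only:
    # the helper name and the MAX_TOKENS_TO_OOM value are assumptions,
    # not taken from Megatron-LM.
    MAX_TOKENS_TO_OOM = 180000  # placeholder limit; the real constant is defined in the module

    def check_scoring_memory(max_prompt_length, batch_size, max_position_embeddings):
        """Mirror of the two checks that replace the old min(...) clamp."""
        # Over-long prompts are now rejected instead of being clamped to
        # max_position_embeddings, which under-sized the buffers when
        # tokens_to_generate == 0.
        if max_prompt_length > max_position_embeddings:
            raise ValueError("Length of prompt + tokens_to_generate longer than allowed")
        # Batches whose total token count could exhaust memory are rejected up front.
        if max_prompt_length * batch_size >= MAX_TOKENS_TO_OOM:
            raise ValueError("Too many tokens. " + str(max_prompt_length * batch_size)
                             + " is greater than " + str(MAX_TOKENS_TO_OOM))

    # Example: scoring a batch of 4 prompts of 2048 tokens each passes both checks
    # for a model with 2048 position embeddings; a 4096-token prompt would raise.
    check_scoring_memory(max_prompt_length=2048, batch_size=4, max_position_embeddings=2048)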