OpenDAS / Megatron-LM · Commits

Commit 1d4e8760
Authored Dec 10, 2020 by Jared Casper; committed by Deepak Narayanan on Dec 19, 2020
Fix text generation without recompute
Parent: 2623551d

Showing 1 changed file with 22 additions and 15 deletions.
megatron/text_generation_utils.py @ 1d4e8760
@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
                                           group=mpu.get_model_parallel_group())
             terminate_runs = input_info_tensor[0].item()
             raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
 
             if terminate_runs == 1:
                 return
 
-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
             if mpu.get_tensor_model_parallel_rank() == 0 \
                and args.pipeline_model_parallel_size > 1:
                 if mpu.is_pipeline_first_stage():
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                     torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
+                    context_length = input_info_tensor[2].item()
                     context_tokens_tensor = torch.empty(context_length,
                                                         dtype=torch.int64,
                                                         device=torch.device("cuda"))
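The rewritten branch above broadcasts the context length and tokens from the first pipeline stage to every other stage over the pipeline-model-parallel group (previously only the last stage listened, on the embedding group). As a point of reference, the same "length first, then tokens" hand-off can be sketched with plain torch.distributed. The snippet below is an illustration only, not Megatron code: the gloo backend, the default world group, and the token values are all assumptions made for this sketch.

# Minimal sketch of the "size first, then tokens" broadcast pattern.
# Launch with:  torchrun --nproc_per_node=2 broadcast_sketch.py
# Assumption: gloo on CPU and the default world group stand in for
# mpu.get_pipeline_model_parallel_group(); rank 0 plays the first pipeline stage.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    src = 0  # stand-in for mpu.get_pipeline_model_parallel_first_rank()

    if rank == src:
        context_tokens = [101, 2023, 2003, 1037, 7953, 102]  # made-up token ids
        length = torch.tensor([len(context_tokens)], dtype=torch.int64)
        tokens = torch.tensor(context_tokens, dtype=torch.int64)
    else:
        length = torch.empty(1, dtype=torch.int64)
        tokens = None

    # Agree on the length first so receiving ranks can size their buffers,
    # then broadcast the tokens themselves.
    dist.broadcast(length, src)
    if rank != src:
        tokens = torch.empty(int(length.item()), dtype=torch.int64)
    dist.broadcast(tokens, src)

    print(f"rank {rank} received tokens: {tokens.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()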
@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
                                           group=mpu.get_model_parallel_group())
             terminate_runs = input_info_tensor[0].item()
             raw_text_len = input_info_tensor[1].item()
-            context_length = input_info_tensor[2].item()
 
             if terminate_runs == 1:
                 return
 
-            # For pipeline parallel we send context tokens to last stage
-            # so it knows when to start overwriting
+            # For pipeline parallel we send context tokens to other stages
+            # so they get the lengths correct
             if mpu.get_tensor_model_parallel_rank() == 0 \
                and args.pipeline_model_parallel_size > 1:
                 if mpu.is_pipeline_first_stage():
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
                     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
                     torch.distributed.broadcast(context_tokens_tensor, src, group)
-                if mpu.is_pipeline_last_stage():
+                else:
                     src = mpu.get_pipeline_model_parallel_first_rank()
-                    group = mpu.get_embedding_group()
+                    group = mpu.get_pipeline_model_parallel_group()
+                    context_length = input_info_tensor[2].item()
                     context_tokens_tensor = torch.empty(context_length,
                                                         dtype=torch.int64,
                                                         device=torch.device("cuda"))
@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
             context_tokens = context_tokens_tensor.cpu().numpy().tolist()
 
             token_stream = get_token_stream(model, [context_tokens])
             for counter, decode_tokens in enumerate(token_stream):
                 if counter % print_frequency != 0 \
                    or mpu.get_tensor_model_parallel_rank() != 0 \
@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):
 
+    # Hidden size changes when not using recompute, need to tell communicate()
+    # the correct size
+    args = get_args()
+    orig_seq_length = args.seq_length
+    args.seq_length = tokens.shape[1]
+
     if not mpu.is_pipeline_first_stage():
         input_tensor, _ = communicate(
             tensor_send_next=None,
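The new comment explains why the override exists: communicate() sizes its pipeline buffers from args.seq_length, and without recompute the incremental decoding step carries far fewer tokens than the training-time sequence length, so the two diverge. A toy illustration of that mismatch follows; the (sequence, batch, hidden) layout and the recv_buffer_shape helper are assumptions made for this sketch, not Megatron's API.

# Hypothetical illustration: the receive buffer is sized from args.seq_length,
# so it must match the number of tokens actually processed in this step.
import torch

def recv_buffer_shape(seq_length, micro_batch_size, hidden_size):
    # Assumed (sequence, batch, hidden) layout for pipeline activations.
    return (seq_length, micro_batch_size, hidden_size)

args_seq_length = 1024                         # training-time value (made up)
tokens = torch.zeros(1, 1, dtype=torch.int64)  # one token per step when reusing the KV cache

print(recv_buffer_shape(args_seq_length, 1, 4096))  # (1024, 1, 4096) -- too big for this step
print(recv_buffer_shape(tokens.shape[1], 1, 4096))  # (1, 1, 4096)    -- matches the activation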
@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                     tensor_send_prev=None,
                     recv_forward=False,
                     recv_backward=False)
-        return None
 
+    args.seq_length = orig_seq_length
     if get_key_value:
         return output_tensor, layer_past
     return output_tensor
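Taken together, the two forward_step hunks bracket the forward pass: save args.seq_length, override it with tokens.shape[1], run the model and the pipeline communication, then restore the original value before returning. A standalone sketch of that pattern follows; SimpleNamespace and the dummy forward body are stand-ins for get_args() and the real model, and the try/finally is this sketch's choice, where the patch itself restores the value with a plain assignment.

# Sketch of the save/override/restore pattern around a forward pass.
from types import SimpleNamespace
import torch

args = SimpleNamespace(seq_length=1024)  # stand-in for get_args()

def forward_step_sketch(tokens):
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]  # tell downstream code the true length of this step
    try:
        # Placeholder for the model forward and communicate() calls.
        output = torch.zeros(tokens.shape[0], tokens.shape[1], 8)
    finally:
        args.seq_length = orig_seq_length  # later callers see the original setting again
    return output

out = forward_step_sketch(torch.zeros(2, 1, dtype=torch.int64))
print(out.shape, args.seq_length)  # torch.Size([2, 1, 8]) 1024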
@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     if type_ids is not None:
                         types2use = type_ids[:, context_length - 1].view(
                             batch_size, -1)
-                output, layer_past = forward_step(model, tokens2use,
+                output, layer_past = forward_step(model, tokens2use,
                                                   positions2use,
                                                   attention_mask,
                                                   layer_past=layer_past,
@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                                   forward_method_parallel_output=False)
                 if mpu.is_pipeline_last_stage():
                     assert output is not None
-                    logits = logits[:, -1].view(batch_size, -1).contiguous()
+                    logits = output[:, -1].view(batch_size, -1).contiguous()
 
             if mpu.is_pipeline_last_stage():
                 if args.greedy:
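These last two hunks rename the forward_step result to output and slice the final position from it on the last pipeline stage; the greedy branch that follows then argmaxes those logits. A tiny self-contained sketch of that slice-and-sample step (shapes and values are made up for illustration):

# Sketch: keep only the newest position's logits, then pick tokens greedily.
import torch

batch_size, seq_len, vocab_size = 2, 5, 11
output = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for forward_step's output

logits = output[:, -1].view(batch_size, -1).contiguous()  # logits for the last position
prev = torch.argmax(logits, dim=-1).view(-1)              # greedy next-token choice

print(logits.shape, prev.shape)  # torch.Size([2, 11]) torch.Size([2])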