Commit 1d4e8760 in wuxk1/Megatron-LM
Authored Dec 10, 2020 by Jared Casper; committed Dec 19, 2020 by Deepak Narayanan
Fix text generation without recompute
Parent: 2623551d
Showing 1 changed file, megatron/text_generation_utils.py, with 22 additions and 15 deletions (+22, -15).
megatron/text_generation_utils.py
@@ -138,23 +138,23 @@ def generate_samples_input_from_file(model):
                                         group=mpu.get_model_parallel_group())
     terminate_runs = input_info_tensor[0].item()
     raw_text_len = input_info_tensor[1].item()
-    context_length = input_info_tensor[2].item()
 
     if terminate_runs == 1:
         return
 
-    # For pipeline parallel we send context tokens to last stage
-    # so it knows when to start overwriting
+    # For pipeline parallel we send context tokens to other stages
+    # so they get the lengths correct
     if mpu.get_tensor_model_parallel_rank() == 0 \
        and args.pipeline_model_parallel_size > 1:
         if mpu.is_pipeline_first_stage():
             src = mpu.get_pipeline_model_parallel_first_rank()
-            group = mpu.get_embedding_group()
+            group = mpu.get_pipeline_model_parallel_group()
             context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
             torch.distributed.broadcast(context_tokens_tensor, src, group)
-        if mpu.is_pipeline_last_stage():
+        else:
             src = mpu.get_pipeline_model_parallel_first_rank()
-            group = mpu.get_embedding_group()
+            group = mpu.get_pipeline_model_parallel_group()
+            context_length = input_info_tensor[2].item()
             context_tokens_tensor = torch.empty(context_length,
                                                 dtype=torch.int64,
                                                 device=torch.device("cuda"))
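Context for this hunk: in Megatron-LM the embedding group ties together only the first and last pipeline stages (it exists to synchronize the shared embedding weights), so broadcasting the context tokens over it never reached the intermediate stages; broadcasting over the pipeline-model-parallel group reaches every stage, which is what the new comment means by "so they get the lengths correct". Below is a minimal, self-contained sketch of the send/receive pattern the new code uses; broadcast_context_tokens is a hypothetical helper, and an initialized mpu/torch.distributed state is assumed:

import torch

def broadcast_context_tokens(context_tokens, input_info_tensor, mpu):
    # All pipeline stages call this collectively: the first stage sends,
    # the others receive into a buffer sized from the shared length info.
    src = mpu.get_pipeline_model_parallel_first_rank()
    group = mpu.get_pipeline_model_parallel_group()
    if mpu.is_pipeline_first_stage():
        tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_length = input_info_tensor[2].item()
        tensor = torch.empty(context_length, dtype=torch.int64,
                             device=torch.device("cuda"))
    torch.distributed.broadcast(tensor, src, group)
    return tensor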
@@ -229,23 +229,23 @@ def generate_samples_interactive(model, print_frequency=24):
                                         group=mpu.get_model_parallel_group())
     terminate_runs = input_info_tensor[0].item()
     raw_text_len = input_info_tensor[1].item()
-    context_length = input_info_tensor[2].item()
 
     if terminate_runs == 1:
         return
 
-    # For pipeline parallel we send context tokens to last stage
-    # so it knows when to start overwriting
+    # For pipeline parallel we send context tokens to other stages
+    # so they get the lengths correct
     if mpu.get_tensor_model_parallel_rank() == 0 \
        and args.pipeline_model_parallel_size > 1:
         if mpu.is_pipeline_first_stage():
             src = mpu.get_pipeline_model_parallel_first_rank()
-            group = mpu.get_embedding_group()
+            group = mpu.get_pipeline_model_parallel_group()
             context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
             torch.distributed.broadcast(context_tokens_tensor, src, group)
-        if mpu.is_pipeline_last_stage():
+        else:
             src = mpu.get_pipeline_model_parallel_first_rank()
-            group = mpu.get_embedding_group()
+            group = mpu.get_pipeline_model_parallel_group()
+            context_length = input_info_tensor[2].item()
             context_tokens_tensor = torch.empty(context_length,
                                                 dtype=torch.int64,
                                                 device=torch.device("cuda"))
@@ -253,6 +253,7 @@ def generate_samples_interactive(model, print_frequency=24):
             context_tokens = context_tokens_tensor.cpu().numpy().tolist()
 
     token_stream = get_token_stream(model, [context_tokens])
+
     for counter, decode_tokens in enumerate(token_stream):
         if counter % print_frequency != 0 \
            or mpu.get_tensor_model_parallel_rank() != 0 \
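The interactive loop here throttles its terminal updates: it skips every iteration except multiples of print_frequency and, per the continuation lines, everything except tensor-model-parallel rank 0. A toy, runnable sketch of that consumption pattern, with a plain generator standing in for get_token_stream (fake_stream and consume are illustrative stand-ins, not the module's API):

def fake_stream():
    # Stand-in for get_token_stream: yields the decoded text so far.
    text = ""
    for word in ["Megatron", "generates", "text", "token", "by", "token"]:
        text += word + " "
        yield text

def consume(token_stream, print_frequency=2):
    latest = None
    for counter, decoded in enumerate(token_stream):
        latest = decoded
        if counter % print_frequency != 0:
            continue  # refresh the output only every print_frequency steps
        print(decoded)
    return latest  # the final decoded sequence

consume(fake_stream())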
@@ -394,6 +395,12 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):
 
+    # Hidden size changes when not using recompute, need to tell communicate()
+    # the correct size
+    args = get_args()
+    orig_seq_length = args.seq_length
+    args.seq_length = tokens.shape[1]
+
     if not mpu.is_pipeline_first_stage():
         input_tensor, _ = communicate(
             tensor_send_next=None,
@@ -437,8 +444,8 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                     tensor_send_prev=None,
                     recv_forward=False,
                     recv_backward=False)
-        return None
 
+    args.seq_length = orig_seq_length
     if get_key_value:
         return output_tensor, layer_past
     return output_tensor
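These two forward_step hunks are the heart of the fix: without recompute, each decode step feeds only the not-yet-cached tokens through the model, so the tensors flowing between pipeline stages are shorter than args.seq_length, and communicate() sizes its receive buffers from that value. Hence the save/override/restore around the call, and the removal of the early return None, which would have bypassed the restore on non-last stages. The same bracket could be written as a context manager; a hypothetical refactoring sketch, not code from this commit:

from contextlib import contextmanager

@contextmanager
def override_seq_length(args, new_length):
    # Temporarily point args.seq_length at the per-step token count so
    # downstream buffer sizing (e.g. in communicate()) sees the true shape.
    orig = args.seq_length
    args.seq_length = new_length
    try:
        yield
    finally:
        args.seq_length = orig  # restore even if the forward pass raises

# Hypothetical usage inside forward_step:
#   with override_seq_length(get_args(), tokens.shape[1]):
#       ...pipeline communication and model forward...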
@@ -495,7 +502,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             if type_ids is not None:
                 types2use = type_ids[:, context_length - 1].view(
                     batch_size, -1)
-            logits, layer_past = forward_step(model, tokens2use,
+            output, layer_past = forward_step(model, tokens2use,
                                               positions2use,
                                               attention_mask,
                                               layer_past=layer_past,
@@ -504,7 +511,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                               forward_method_parallel_output=False)
             if mpu.is_pipeline_last_stage():
                 assert output is not None
-                logits = logits[:, -1].view(batch_size, -1).contiguous()
+                logits = output[:, -1].view(batch_size, -1).contiguous()
 
         if mpu.is_pipeline_last_stage():
             if args.greedy:
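The last two hunks fix a plain variable-name bug in the no-recompute path: the result of forward_step was assigned to logits, yet the very next lines asserted and sliced a name, output, that belongs to the recompute path, so generation without recompute misbehaved. With the rename, only the last pipeline stage, where forward_step actually returns a tensor, slices out the final position before sampling. A minimal sketch of that slice plus the greedy pick that args.greedy gates; the shapes are assumptions based on the [batch, sequence, vocab] layout implied by output[:, -1]:

import torch

def greedy_next_token(output, batch_size):
    # output: [batch_size, seq_len, vocab_size], present only on the
    # last pipeline stage (earlier stages have no logits).
    assert output is not None
    logits = output[:, -1].view(batch_size, -1).contiguous()  # final position
    return torch.argmax(logits, dim=-1).view(-1)  # one token id per sample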