OpenDAS / Megatron-LM

Commit e5034150, authored Jul 15, 2022 by peng xu
Parent: da11c982

fix pipeline parallel for beam search
Showing 1 changed file with 48 additions and 25 deletions.

megatron/text_generation/generation.py (+48, -25), view file @ e5034150
```diff
@@ -300,10 +300,12 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
 
     beam_hyp = BeamHypotheses(beam_size, length_penalty)
-    done = False
+    best_batches = None
+    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
     scores = torch.zeros(beam_size,
                          dtype=torch.float32,
                          device=torch.cuda.current_device()).unsqueeze(1)
+    scores_size_tensor, tokens_size_tensor = None, None
     # =============
     # Run infernece
     # =============
```
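Note on this hunk: `done` changes from a Python bool to a one-element `uint8` CUDA tensor, and the two `*_size_tensor` placeholders are declared up front, because only the last pipeline stage can decide that beam search has finished, and that decision now has to reach the other stages through `torch.distributed`, which only moves tensors. A minimal sketch of the idea, not Megatron's own helper; the rank/group arguments stand in for the `mpu` bookkeeping:

```python
# Sketch: a termination flag that must travel between pipeline stages has to
# be a device tensor, because torch.distributed.broadcast operates on tensors,
# not on Python bools. The src rank / group arguments are illustrative only.
import torch
import torch.distributed as dist

def broadcast_flag_from_last_stage(flag, last_stage_global_rank, group=None):
    # Every rank passes a 1-element uint8 tensor; only the last stage's value
    # matters, all other ranks have theirs overwritten in place.
    assert flag.dtype == torch.uint8 and flag.numel() == 1
    dist.broadcast(flag, src=last_stage_global_rank, group=group)
    return flag
```

The `broadcast_from_last_pipeline_stage(1, torch.uint8, done)` call added further down in this diff plays this role in the actual code.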
```diff
@@ -321,6 +323,10 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
             # logits will be meanigful only in the last pipeline stage.
             logits = forward_step(tokens2use, positions2use, attention_mask2use)
 
+            # if mpu.is_pipeline_first_stage():
+            #     print('-' * 40)
+            #     print(tokens[:, context_length-5:context_length+5])
+            #     print(context_length)
             if mpu.is_pipeline_last_stage():
                 vocab_size = logits.size(2)
```
```diff
@@ -335,6 +341,10 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                 best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
                 best_words = indices[:2 * beam_size] % vocab_size
                 best_scores = sorted_scores[: 2 * beam_size]
 
+                # print('*' * 40)
+                # print(best_beam_ids)
+                # print(best_words)
+                # print(context_length)
                 next_beams = []
                 for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
```
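Note on the context lines above: the selected indices point into a flattened `(beam_size * vocab_size)` score vector, so each one is split back into a beam id (`torch.div(..., vocab_size).trunc()`) and a token id (`% vocab_size`). A standalone illustration with made-up numbers:

```python
# Recovering (beam_id, token_id) pairs from indices into a flattened
# beam-by-vocab score matrix, mirroring the arithmetic in the hunk above.
import torch

beam_size, vocab_size = 2, 5
new_scores = torch.tensor([[0.1, 0.9, 0.2, 0.0, 0.3],   # beam 0
                           [0.8, 0.1, 0.7, 0.4, 0.0]])  # beam 1
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size

print(best_beam_ids.tolist())  # [0, 1, 1, 1]: which beam each candidate extends
print(best_words.tolist())     # [1, 0, 2, 3]: which token each candidate appends
```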
```diff
@@ -358,40 +368,53 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                         break
 
                 if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
-                    done = True
-                    break
+                    done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
+                    print("find all hyp exiting")
 
                 best_batches = tokens.new([item[2] for item in next_beams])
                 tokens = tokens[best_batches,:]
                 tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
                 scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
 
-                # set inference key values to make it consistent with best beam index
-                forward_step.inference_params.swap_key_value_dict(best_batches)
```
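Note on this part of the hunk: the surviving beams are a reindexing of the previous step's beams (`best_batches`), so the key/value cache held by `forward_step.inference_params` must be reordered the same way. Before this commit the `swap_key_value_dict` call sat inside the last-stage-only block removed above; the lines added below broadcast `best_batches` first so that every pipeline stage can reorder its own cache. A minimal sketch of such a reorder; the `{layer: (key, value)}` layout and batch-first shapes are assumptions for illustration, not Megatron's actual `InferenceParams` internals:

```python
# Sketch: reorder every layer's cached keys/values so that slot i holds the
# state of the beam that survived into slot i (beam_idx[i]). Hypothetical
# cache layout, chosen only to illustrate the reindexing.
import torch

def reorder_kv_cache(kv_cache, beam_idx):
    for layer, (key, value) in kv_cache.items():
        kv_cache[layer] = (key.index_select(0, beam_idx),
                           value.index_select(0, beam_idx))
    return kv_cache

# Example: beam 0 survives once, beam 2 survives twice.
cache = {0: (torch.randn(3, 4, 8), torch.randn(3, 4, 8))}
cache = reorder_kv_cache(cache, torch.tensor([0, 2, 2]))
```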
```diff
+            # torch.distributed.barrier()
+            done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
+            if done:
+                print("break for loop")
+                break
 
             # Update the tokens on the first stage so the next input to
             # the network is correct.
-            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64, tokens[:, context_length])
+            copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64, tokens)
 
+            # set inference key values to make it consistent with best beam index
+            best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
+            forward_step.inference_params.swap_key_value_dict(best_batches)
 
             # Update the context length for the next token generation.
             prev_context_length = context_length
 
+            copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32, scores[:, 0])
```
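Note on the synchronization added here: the last stage reorders whole rows of `tokens` via `best_batches`, so copying only the newest column (`tokens[:, context_length]`, the old behaviour) leaves the first stage with stale prefixes; this appears to be the pipeline-parallel inconsistency the commit fixes. The new code copies the full `tokens` tensor and also broadcasts `done`, `best_batches`, and the running scores every step. A standalone illustration (made-up values) of how the single-column copy drifts out of sync:

```python
# Why copying only the newest token column is not enough once beams are
# reordered: the first stage keeps old row prefixes that no longer exist
# on the last stage. Values are made up for illustration.
import torch

tokens_last = torch.tensor([[5, 6, 0],
                            [7, 8, 0]])        # last stage's view
tokens_first = tokens_last.clone()             # first stage's copy

best_batches = torch.tensor([1, 1])            # both survivors extend old beam 1
tokens_last = tokens_last[best_batches, :]
tokens_last[:, 2] = torch.tensor([9, 3])       # newly chosen tokens

tokens_first[:, 2] = tokens_last[:, 2]         # old behaviour: column only
print(tokens_first.tolist())  # [[5, 6, 9], [7, 8, 3]]  stale prefixes
print(tokens_last.tolist())   # [[7, 8, 9], [7, 8, 3]]  correct beams
```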
```diff
-        # if cannot find stop token, add open beams to hyps
-        if not done:
-            for beam_id in range(beam_size):
-                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
-
-        # rank based on scores
-        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
-        num_return_gen = min(num_return_gen, len(sorted_hyps))
-        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
-        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
-        scores = torch.stack(scores, dim=0)
-        tokens = torch.stack(tokens, dim=0)
+        if mpu.is_pipeline_last_stage():
+            # if cannot find stop token, add open beams to hyps
+            if not done:
+                for beam_id in range(beam_size):
+                    beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+
+            # rank based on scores
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+            num_return_gen = min(num_return_gen, len(sorted_hyps))
+            scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+            tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+            scores = torch.stack(scores, dim=0)
+            tokens = torch.stack(tokens, dim=0)
+            scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64, device=torch.cuda.current_device())
+            tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64, device=torch.cuda.current_device())
+
+        scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
+        tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
+
+        scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor), torch.float32, scores)
+        tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor), torch.int64, tokens)
 
     return tokens, scores
```
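Note on the final hunk: `num_return_gen` is clamped to the number of finished hypotheses, which only the last stage knows, so the result shapes (`scores_size_tensor`, `tokens_size_tensor`) are broadcast before the `scores`/`tokens` payloads, letting the other stages allocate matching receive buffers. A minimal sketch of that shape-then-payload pattern; the rank/group handling is a stand-in for Megatron's helpers, not their actual signatures:

```python
# Sketch: broadcast a result whose shape is only known on the source rank by
# sending the shape first, then the payload into a freshly allocated buffer.
import torch
import torch.distributed as dist

def broadcast_result(result, ndims, dtype, src_rank, group=None):
    device = torch.cuda.current_device()
    # 1) shape: known only on src_rank (e.g. the last pipeline stage)
    if result is not None:
        size = torch.tensor(result.shape, dtype=torch.int64, device=device)
    else:
        size = torch.empty(ndims, dtype=torch.int64, device=device)
    dist.broadcast(size, src=src_rank, group=group)
    # 2) payload: receivers allocate a buffer of the broadcast shape
    if result is None:
        result = torch.empty(tuple(size.tolist()), dtype=dtype, device=device)
    dist.broadcast(result, src=src_rank, group=group)
    return result
```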