OpenDAS / Megatron-LM · Commits · 4f579b55

Commit 4f579b55 authored Apr 14, 2022 by Peng Xu
fix beam search
Parent: de84b2af

Showing 3 changed files with 111 additions and 10 deletions (+111, -10)
megatron/text_generation/api.py          (+2, -2)
megatron/text_generation/forward_step.py (+12, -1)
megatron/text_generation/generation.py   (+97, -7)
megatron/text_generation/api.py

```diff
@@ -163,7 +163,7 @@ def beam_search_and_post_process(model,
 
         return None
 
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256):
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
               beam_size,
@@ -176,4 +176,4 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
 
-    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size)
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token)
```
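As a usage sketch only (not part of the commit): the new `stop_token` argument defaults to 50256, GPT-2's `<|endoftext|>` id, and is simply threaded through to `beam_search_and_return_on_first_stage`. The wrapper below is hypothetical; it assumes `model` is a Megatron GPT model already initialized on every rank, and the prompt, beam size, and generation length are made-up values.

```python
# Hypothetical helper around the updated API; `model` is assumed to be an
# already-initialized Megatron GPT model (not constructed here).
from megatron.text_generation.api import beam_search

def run_beam_search(model, prompt, eos_id=50256):
    # stop_token defaults to 50256 (GPT-2 <|endoftext|>); pass a different id
    # when the tokenizer in use ends sequences with another token.
    tokens, scores = beam_search(
        model,
        prompts=[prompt],          # illustrative single prompt
        tokens_to_generate=32,     # illustrative lengths/sizes
        beam_size=4,
        add_BOS=False,
        stop_token=eos_id,
    )
    return tokens, scores
```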
megatron/text_generation/forward_step.py

```diff
@@ -42,7 +42,18 @@ class InferenceParams:
         self.batch_size_offset = 0
         self.key_value_memory_dict = {}
 
+    def swap_key_value_dict(self, batch_idx):
+        "swap between batches"
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError("should not swap when dict in empty")
+
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+            assert len(batch_idx) == inference_key_memory.shape[1]  ## make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (new_inference_key_memory, new_inference_value_memory)
+
 
 class ForwardStep:
     """Forward step function with all the communications.
```
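For orientation (not part of the commit): `swap_key_value_dict` reindexes each layer's cached keys and values along the batch dimension so the inference cache follows whichever beams survived the latest step. A minimal standalone sketch of that reindexing, with made-up shapes; the only assumption taken from the diff is that batch sits in dimension 1 of the cached tensors (the `assert` on `shape[1]`).

```python
import torch

# Made-up cache shapes: [max_seq_len, batch(=beam), num_heads, head_dim].
max_seq_len, beam_size, num_heads, head_dim = 16, 4, 2, 8
key_cache = torch.randn(max_seq_len, beam_size, num_heads, head_dim)
value_cache = torch.randn(max_seq_len, beam_size, num_heads, head_dim)

# Suppose the beam step decided that beams 2 and 0 each continue twice.
best_batches = torch.tensor([2, 2, 0, 0])

# Same fancy indexing as swap_key_value_dict: reorder the cache per beam.
key_cache = key_cache[:, best_batches]
value_cache = value_cache[:, best_batches]

assert key_cache.shape == (max_seq_len, beam_size, num_heads, head_dim)
```

Without this reordering the cache would still reflect the pre-reshuffle beam order after `tokens` has been permuted, which is exactly the inconsistency the in-diff comment "set inference key values to make it consistent with best beam index" refers to.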
megatron/text_generation/generation.py

```diff
@@ -258,7 +258,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                            tensor=done)
         if use_eod_token_for_early_termination and done:
             break
 
     # ===================================================
     # Update the length of based on max generated length.
     # ===================================================
```
```diff
@@ -281,8 +281,54 @@ def generate_tokens_probs_and_return_on_first_stage(
 
     return tokens, generated_sequence_lengths, output_log_probs
 
-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
+## from huggingface beam search
+class BeamHypotheses(object):
+    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp, sum_logprobs, length):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / length ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len):
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token):
     args = get_args()
     tokenizer = get_tokenizer()
```
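A quick aside on the scoring rule the class implements: hypotheses are ranked by `sum_logprobs / length ** self.length_penalty`, and `is_done` reports completion once the best score a live beam could still reach can no longer beat the worst kept hypothesis. The numbers below are invented purely to show the arithmetic.

```python
# Invented log-probabilities; BeamHypotheses normalizes by length**length_penalty.
length_penalty = 1.0

# A short hypothesis with a decent total log-prob ...
short_sum_logprobs, short_len = -4.0, 5
# ... versus a longer one with a worse total but better per-token quality.
long_sum_logprobs, long_len = -7.5, 10

short_score = short_sum_logprobs / short_len ** length_penalty   # -0.80
long_score = long_sum_logprobs / long_len ** length_penalty      # -0.75

# After normalization the longer hypothesis ranks higher, so it would displace
# the shorter one once the n-best list already holds num_beams entries.
assert long_score > short_score

# is_done: with worst kept score -0.75, a live beam whose best attainable
# normalized score is -0.90 can never improve the n-best list.
worst_score, best_sum_logprobs, cur_len = -0.75, -9.0, 10
cur_score = best_sum_logprobs / cur_len ** length_penalty        # -0.90
assert worst_score >= cur_score   # the search for this sentence can stop
```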
```diff
@@ -299,6 +345,8 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
 
+    hyp = BeamHypotheses(beam_size)
+    done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,
                              dtype=torch.float32,
```
```diff
@@ -331,13 +379,43 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
            else:
                sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
 
-           best_batches = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
-           best_words = indices[:beam_size] % vocab_size
+           best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
+           best_words = indices[: 2 * beam_size] % vocab_size
+           best_scores = sorted_scores[: 2 * beam_size]
 
+           next_beams = []
+           for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
+               zip(best_words, best_scores, best_beam_ids)
+           ):
+               if token_id.item() == stop_token:
+                   # if beam_token does not belong to top num_beams tokens, it should not be added
+                   is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
+                   if is_beam_token_worse_than_top_num_beams:
+                       continue
+                   hyp.add(tokens[beam_id].clone(), beam_score, context_length + 1 - prompt_length)
+               else:
+                   # add next predicted token since it is not eos_token
+                   next_beams.append((token_id, beam_score, beam_id))
+
+               if len(next_beams) == beam_size:
+                   break
+
+           if hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+               done = True
+               break
+
+           best_batches = tokens.new([item[2] for item in next_beams])
            tokens = tokens[best_batches,:]
-           tokens[:, context_length] = best_words
-           scores = sorted_scores[:beam_size].unsqueeze(1)
+           tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
+           scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
 
+           # set inference key values to make it consistent with best beam index
+           forward_step.inference_params.swap_key_value_dict(best_batches)
+
            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
```
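The reworked selection sorts the flattened `[beam_size, vocab_size]` score matrix and keeps `2 * beam_size` candidates, so that even if up to `beam_size` of them are the stop token (and are routed into `hyp` rather than kept alive), enough candidates remain to refill every beam; `indices // vocab_size` recovers the source beam and `indices % vocab_size` the token id. A toy sketch of that index arithmetic (sizes and scores are made up):

```python
import torch

# Toy setup: 2 beams over a 5-token vocabulary (sizes are invented).
beam_size, vocab_size = 2, 5
new_scores = torch.tensor([[-1.0, -0.2, -3.0, -0.5, -2.0],
                           [-0.1, -4.0, -0.3, -1.5, -2.5]])

sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

# Keep 2 * beam_size candidates so stop-token hits can be absorbed.
top = indices[: 2 * beam_size]
beam_ids = torch.div(top, vocab_size).trunc().long()  # which beam each came from
word_ids = top % vocab_size                           # which vocabulary token

# Highest scores are -0.1 (beam 1, token 0), -0.2 (beam 0, token 1),
# -0.3 (beam 1, token 2), -0.5 (beam 0, token 3).
assert beam_ids.tolist() == [1, 0, 1, 0]
assert word_ids.tolist() == [0, 1, 2, 3]
```

Drawing twice as many candidates as beams mirrors the Hugging Face beam search that the `## from huggingface beam search` comment credits.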
```diff
@@ -348,6 +426,18 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
     copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                            scores[:, 0])
 
+    # if cannot find stop token, add open beams to hyps
+    if not done:
+        for beam_id in range(beam_size):
+            hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+
+    # rank based on scores
+    sorted_hyps = sorted(hyp.beams, key=lambda x: x[0], reverse=True)
+    scores, tokens = sorted_hyps[0]
+    scores = scores.unsqueeze(0)
+    tokens = tokens.unsqueeze(0)
+
     return tokens, scores
```