OpenDAS / Megatron-LM · Commits

Commit 96816d3d
Authored May 03, 2022 by peng xu
Parent: 4f579b55

    rename hyp and allow return multiple samples
Showing 2 changed files with 19 additions and 13 deletions (+19, -13):

  megatron/text_generation/api.py        (+9, -4)
  megatron/text_generation/generation.py (+10, -9)
File: megatron/text_generation/api.py @ 96816d3d
```diff
@@ -144,7 +144,9 @@ def beam_search_and_post_process(model,
                                  prompts=None,
                                  tokens_to_generate=0,
                                  beam_size=0,
-                                 add_BOS=False):
+                                 add_BOS=False,
+                                 stop_token=50256,
+                                 num_return_gen=1):
     """Run beam search and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""

@@ -153,7 +155,9 @@ def beam_search_and_post_process(model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         beam_size=beam_size,
-        add_BOS=add_BOS)
+        add_BOS=add_BOS,
+        stop_token=stop_token,
+        num_return_gen=num_return_gen)

     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())

@@ -163,7 +167,7 @@ def beam_search_and_post_process(model,
     return None

-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1):
     # Make sure input params are available to all ranks.
     values = [tokens_to_generate,
               beam_size,

@@ -176,4 +180,5 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

-    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token)
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token, num_return_gen=num_return_gen)
```
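Taken together, the api.py changes thread two new keyword arguments through both public entry points. A minimal usage sketch, assuming an already-built Megatron model with initialized distributed state; the prompt text and generation lengths are illustrative, and the return structure is elided since, per the code above, non-first pipeline stages return None:

```python
# Hypothetical call site mirroring the updated signature above.
from megatron.text_generation.api import beam_search_and_post_process

result = beam_search_and_post_process(
    model,                         # assumed: a built, distributed Megatron model
    prompts=["Deep learning is"],  # illustrative prompt
    tokens_to_generate=64,
    beam_size=4,
    add_BOS=False,
    stop_token=50256,              # default above: GPT-2 BPE end-of-text id
    num_return_gen=2)              # new: ask for the top-2 beams instead of 1
# result is populated only on the first pipeline stage; other ranks get None.
```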
File: megatron/text_generation/generation.py @ 96816d3d
```diff
@@ -328,7 +328,7 @@ class BeamHypotheses(object):
             ret = self.worst_score >= cur_score
         return ret

-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token):
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen=1):
     args = get_args()
     tokenizer = get_tokenizer()

@@ -345,7 +345,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)

-    hyp = BeamHypotheses(beam_size)
+    beam_hyp = BeamHypotheses(beam_size)
     done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,

@@ -392,7 +392,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
-                    hyp.add(
+                    beam_hyp.add(
                         tokens[beam_id].clone(),
                         beam_score,
                         context_length + 1 - prompt_length

@@ -404,7 +404,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                     if len(next_beams) == beam_size:
                         break

-                if hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                     done = True
                     break

@@ -430,13 +430,14 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
         # if cannot find stop token, add open beams to hyps
         if not done:
             for beam_id in range(beam_size):
-                hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)

         # rank based on scores
-        sorted_hyps = sorted(hyp.beams, key=lambda x: x[0], reverse=True)
-        scores, tokens = sorted_hyps[0]
-        scores = scores.unsqueeze(0)
-        tokens = tokens.unsqueeze(0)
+        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+        scores = torch.stack(scores, dim=0)
+        tokens = torch.stack(tokens, dim=0)

         return tokens, scores
```
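The tail of the last hunk is the behavioral core of the commit: instead of unpacking only sorted_hyps[0] and unsqueezing, the top num_return_gen hypotheses are stacked along a new leading dimension. A self-contained sketch of just that selection step, with dummy hypotheses standing in for beam_hyp.beams (names and values here are illustrative, not from the commit):

```python
import torch

# Dummy (score, tokens) pairs standing in for beam_hyp.beams; in the real
# code the scores are 0-d tensors, so torch.stack applies to them directly.
beams = [(torch.tensor(0.5), torch.tensor([1, 2, 3])),
         (torch.tensor(0.9), torch.tensor([4, 5, 6])),
         (torch.tensor(0.1), torch.tensor([7, 8, 9]))]
num_return_gen = 2

# Same selection logic as the new "+" lines: sort by score, keep the top-k.
sorted_hyps = sorted(beams, key=lambda x: x[0], reverse=True)
scores = torch.stack([sorted_hyps[i][0] for i in range(num_return_gen)], dim=0)
tokens = torch.stack([sorted_hyps[i][1] for i in range(num_return_gen)], dim=0)

print(scores)        # tensor([0.9000, 0.5000])
print(tokens.shape)  # torch.Size([2, 3]) -> [num_return_gen, sequence_length]
```

Note the implicit assumption that num_return_gen never exceeds the number of stored hypotheses; sorted_hyps[i] would raise an IndexError otherwise.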