Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
6be75e2a
"awq/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "4d49ae1f75e666aa930f35e81f16d33c92543696"
Commit
6be75e2a
authored
May 16, 2022
by
rprenger
Browse files
Fixing beam search in distributed mode
parent
fd176a90
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
7 deletions
+6
-7
megatron/text_generation/api.py
megatron/text_generation/api.py
+1
-1
megatron/text_generation/generation.py
megatron/text_generation/generation.py
+4
-5
megatron/text_generation_server.py
megatron/text_generation_server.py
+1
-1
No files found.
megatron/text_generation/api.py
View file @
6be75e2a
...
...
@@ -177,7 +177,7 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
stop_token
,
num_return_gen
,
length_penalty
]
values_float_tensor
=
broadcast_float_list
(
3
,
float_list
=
values
)
values_float_tensor
=
broadcast_float_list
(
6
,
float_list
=
values
)
tokens_to_generate
=
int
(
values_float_tensor
[
0
].
item
())
beam_size
=
int
(
values_float_tensor
[
1
].
item
())
add_BOS
=
bool
(
values_float_tensor
[
2
].
item
())
...
...
megatron/text_generation/generation.py
View file @
6be75e2a
...
...
@@ -347,10 +347,9 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
beam_hyp
=
BeamHypotheses
(
beam_size
,
length_penalty
)
done
=
False
if
mpu
.
is_pipeline_last_stage
():
scores
=
torch
.
zeros
(
beam_size
,
dtype
=
torch
.
float32
,
device
=
torch
.
cuda
.
current_device
()).
unsqueeze
(
1
)
scores
=
torch
.
zeros
(
beam_size
,
dtype
=
torch
.
float32
,
device
=
torch
.
cuda
.
current_device
()).
unsqueeze
(
1
)
# =============
# Run inference
# =============
...
...
@@ -368,9 +367,9 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
# logits will be meaningful only in the last pipeline stage.
logits
=
forward_step
(
tokens2use
,
positions2use
,
attention_mask2use
)
vocab_size
=
logits
.
size
(
2
)
if
mpu
.
is_pipeline_last_stage
():
vocab_size
=
logits
.
size
(
2
)
log_probs
=
F
.
log_softmax
(
logits
,
dim
=
2
)
new_scores
=
log_probs
[:,
-
1
,
:]
+
scores
...
...
megatron/text_generation_server.py
View file @
6be75e2a
...
...
@@ -24,7 +24,7 @@ from megatron.text_generation import beam_search_and_post_process
GENERATE_NUM
=
0
BEAM_NUM
=
0
BEAM_NUM
=
1
lock
=
threading
.
Lock
()
class
MegatronGenerate
(
Resource
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment