OpenDAS / Megatron-LM · Commits

Commit b46482e8, authored Sep 14, 2021 by rprenger
Parent: 593b47b4

Fixes a bug in broadcasting that was causing hanging
Showing 3 changed files with 2 additions and 10 deletions (+2, -10):

  megatron/text_generation_server.py     +1  -3
  megatron/text_generation_utils.py      +0  -4
  tools/run_text_generation_server.py    +1  -3
megatron/text_generation_server.py

@@ -30,9 +30,7 @@ class MegatronGenerate(Resource):
     @staticmethod
     def send_do_generate():
         choice = torch.cuda.LongTensor([GENERATE_NUM])
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)

     def put(self):
         args = get_args()
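This hunk is the heart of the fix. Per the commit message, the old group-scoped broadcast caused hangs: only ranks in the server's tensor-model-parallel group could receive the control value, so any rank outside that group never got the signal to join the subsequent collective work. Broadcasting from global rank 0 over the default (world) group reaches every rank. Below is a minimal runnable sketch of the corrected pattern; it is standalone code rather than Megatron's, uses the gloo backend with CPU tensors so it runs without GPUs (Megatron broadcasts a torch.cuda.LongTensor), and the sentinel value 0 mirrors the GENERATE_NUM check seen in tools/run_text_generation_server.py further down:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

GENERATE_NUM = 0  # sentinel meaning "run a generate step"

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    choice = torch.empty(1, dtype=torch.long)
    if rank == 0:
        choice[0] = GENERATE_NUM  # the server side decides the action
    # every rank in the job makes the same call with src=0, so nobody is
    # left waiting on a collective its peers never enter
    dist.broadcast(choice, 0)
    print(f"rank {rank} received choice {choice[0].item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)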
megatron/text_generation_utils.py
@@ -141,7 +141,6 @@ def receive_generate_info():
 def synced_generate(model, context_tokens_tensor, context_length_tensor,
                     tokens_to_generate, all_probs):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
@@ -172,7 +171,6 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
             group = mpu.get_embedding_group()
             full_logits = torch.empty(tokens.size(0), context_length,
                                       args.padded_vocab_size,
                                       dtype=torch.float32,
                                       device=torch.device("cuda"))
             torch.distributed.broadcast(full_logits, src, group)
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
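For context, the hunk above shows the usual receive side of a group broadcast: ranks other than the source allocate an uninitialized tensor of the agreed shape, and torch.distributed.broadcast fills it in place, provided every rank in the group makes the call. A runnable sketch of that pattern (standalone, not Megatron code; group membership, shape, and dtype are illustrative stand-ins for mpu.get_embedding_group() and the (batch, context_length, padded_vocab_size) logits):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # new_group is itself collective: every rank must call it, even ranks
    # that will not be members (stand-in for mpu.get_embedding_group())
    group = dist.new_group(ranks=[0, 1])
    src = 0
    if rank in (0, 1):
        # receivers allocate an empty tensor of the agreed shape;
        # broadcast overwrites it in place on every non-source rank
        full_logits = torch.empty(2, 3, 8, dtype=torch.float32)
        if rank == src:
            full_logits.normal_()  # only the source holds real values
        dist.broadcast(full_logits, src, group=group)
        print(f"rank {rank}: sum = {full_logits.sum().item():.4f}")
    dist.barrier()  # keep non-members alive until the group is done
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)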
@@ -310,7 +308,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
             if type_ids is not None:
                 types2use = type_ids[:, context_length - 1].view(batch_size, -1)
             output, layer_past = forward_step(model, tokens2use, positions2use,
                                               attention_mask,
@@ -332,7 +329,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                           top_p=args.top_p)
                 log_probs = F.softmax(logits, dim=-1)
                 prev = torch.multinomial(log_probs, num_samples=1).view(-1)
             started = context_lengths <= context_length
             new_tokens = switch(
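One aside on the sampling lines in this hunk: despite its name, log_probs holds softmax probabilities, not log-probabilities, and that is what torch.multinomial expects (non-negative weights per row). A small standalone sketch of the draw, with illustrative sizes:

import torch
import torch.nn.functional as F

batch_size, vocab_size = 4, 50257  # illustrative sizes
logits = torch.randn(batch_size, vocab_size)

# softmax turns logits into per-row probabilities ...
probs = F.softmax(logits, dim=-1)
# ... and multinomial draws one token id per batch row from them
prev = torch.multinomial(probs, num_samples=1).view(-1)
print(prev.shape)  # torch.Size([4])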
tools/run_text_generation_server.py

@@ -78,8 +78,6 @@ if __name__ == "__main__":
     while True:
         choice = torch.cuda.LongTensor(1)
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
             generate(model)
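Taken together, the two edited call sites form a simple handshake: rank 0 runs the HTTP server and announces each request by broadcasting a sentinel from global rank 0, while every other rank blocks on the same broadcast inside this loop and then joins the collective generate call. A runnable sketch of the whole loop (standalone, not Megatron code; the STOP sentinel and the stub generate() are assumptions of this sketch, added so it terminates):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

GENERATE_NUM = 0  # "run a generate step", matching the == 0 check above
STOP_NUM = 1      # hypothetical shutdown sentinel, not in the commit

def generate(model, rank):
    print(f"rank {rank}: running one collective generate step")

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29502"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = None  # stand-in for the real model

    if rank == 0:
        # server side: two fake requests, then shut everyone down
        for signal in (GENERATE_NUM, GENERATE_NUM, STOP_NUM):
            choice = torch.tensor([signal])
            dist.broadcast(choice, 0)  # wake every rank
            if signal == GENERATE_NUM:
                generate(model, rank)
    else:
        while True:
            choice = torch.empty(1, dtype=torch.long)
            dist.broadcast(choice, 0)  # blocks until rank 0 signals
            if choice[0].item() == GENERATE_NUM:
                generate(model, rank)  # join the collective step
            else:
                break
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)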