OpenDAS / Megatron-LM · Commits · a9a3ef50
"vscode:/vscode.git/clone" did not exist on "3ed4c0f33fee281fbdc276e208574e22821818d9"
Commit a9a3ef50, authored Jun 30, 2021 by rprenger
Simpler broadcasting and some clean up
parent 5580d661
Showing 2 changed files with 21 additions and 27 deletions (+21 -27):
megatron/api_server.py (+19 -25)
megatron/text_generation_utils.py (+2 -2)
megatron/api_server.py (view file @ a9a3ef50)
...
@@ -54,17 +54,11 @@ class MegatronGenerate(Resource):
         # Send the sizes of the tensors
         input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
         input_info_tensor = torch.cuda.LongTensor(input_info)
-        torch.distributed.broadcast(input_info_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(input_info_tensor, 0)

-        # Now send tensors
-        torch.distributed.broadcast(context_length_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
-        torch.distributed.broadcast(context_tokens_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+        # Send variables to all ranks
+        torch.distributed.broadcast(context_length_tensor, 0)
+        torch.distributed.broadcast(context_tokens_tensor, 0)

     @staticmethod
     def receive_generate_info():
...
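The core of the commit is visible in this hunk: every broadcast that previously named the tensor-model-parallel source rank and group now broadcasts over the default group from global rank 0. The following standalone sketch (not part of the diff) contrasts the two idioms; it assumes a single-process gloo world so it can run on CPU, and the mpu calls are left as comments because they require a real Megatron model-parallel setup.

import os
import torch
import torch.distributed as dist

# Minimal sketch: a one-process "world" is enough to exercise the call
# signatures used in send_generate_info.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

payload = torch.zeros(4, dtype=torch.int64)   # CPU tensor for portability

# Old idiom: broadcast only inside the tensor-model-parallel group, from that
# group's source rank (requires Megatron's mpu to be initialized):
#   dist.broadcast(payload,
#                  mpu.get_tensor_model_parallel_src_rank(),
#                  group=mpu.get_tensor_model_parallel_group())

# New idiom from this commit: broadcast over the default (world) group from
# global rank 0, where the API server runs.
dist.broadcast(payload, 0)

dist.destroy_process_group()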
@@ -72,9 +66,7 @@ class MegatronGenerate(Resource):
         Needs to be synced up with send_generate_info
         """
         input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.device("cuda"))
-        torch.distributed.broadcast(input_info_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(input_info_tensor, 0)
         batch_size = input_info_tensor[0].item()
         seq_len = input_info_tensor[1].item()
         max_len = input_info_tensor[2].item()
...
@@ -82,12 +74,10 @@ class MegatronGenerate(Resource):
         context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
         context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
-        torch.distributed.broadcast(context_length_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
-        torch.distributed.broadcast(context_tokens_tensor, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
+        # Send variables to all ranks
+        torch.distributed.broadcast(context_length_tensor, 0)
+        torch.distributed.broadcast(context_tokens_tensor, 0)
         return context_length_tensor, context_tokens_tensor, max_len

     @staticmethod
...
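receive_generate_info mirrors send_generate_info step for step: a three-element metadata tensor first tells every rank the batch size, sequence length, and max_len, then each non-zero rank allocates empty buffers of exactly that shape before the payload broadcasts fill them. The sketch below is an illustration, not the server code; it collapses both roles into one runnable single-process example. Shapes and dtypes must match on every rank or the collective will hang or error.

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
rank = dist.get_rank()

# Phase 1: rank 0 packs the shapes into a small metadata tensor and
# broadcasts it; the other ranks learn how big the payload will be.
if rank == 0:
    context_tokens = torch.randint(0, 50000, (2, 16), dtype=torch.int64)
    max_len = 64
    input_info = torch.tensor([context_tokens.size(0), context_tokens.size(1), max_len])
else:
    input_info = torch.empty(3, dtype=torch.int64)
dist.broadcast(input_info, 0)
batch_size, seq_len, max_len = (int(x) for x in input_info)

# Phase 2: non-zero ranks allocate an empty buffer of exactly that shape, and
# the payload broadcast fills it in place.
if rank != 0:
    context_tokens = torch.empty(batch_size, seq_len, dtype=torch.int64)
dist.broadcast(context_tokens, 0)

dist.destroy_process_group()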
@@ -100,22 +90,26 @@ class MegatronGenerate(Resource):
         return decode_tokens

     def put(self):
+        args = get_args()
         sentences = request.get_json()["sentences"]
-        max_len = 1024  # TODO (rprenger) this should not be hardcoded
+        max_len = args.seq_length
         if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
+            input_max_len = request.get_json()["max_len"]
+            if input_max_len < args.seq_length:
+                max_len = input_max_len

         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)

         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
         MegatronGenerate.send_generate_info(context_tokens_tensor, context_length_tensor, max_len)  # Send them info
         decode_tokens = MegatronGenerate.do_generate(self.model, context_length_tensor, context_tokens_tensor, max_len)  # Do stuff

-        args = get_args()
         tokenizer = get_tokenizer()
         decode_tokens, _ = decode_tokens
-        decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-        trim_decode_tokens = tokenizer.detokenize(decode_tokens)
-        return jsonify({"sentences": [trim_decode_tokens]})
+        resp_sentences = []
+        for i in range(decode_tokens.size(0)):
+            decode_token = decode_tokens[i,:].cpu().numpy().tolist()
+            resp_sentences.append(tokenizer.detokenize(decode_token))
+        return jsonify({"sentences": resp_sentences})


 class MegatronServer(object):
...
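On the request side, the reworked put() handler means a max_len at or above args.seq_length is clamped to the model's sequence length, and the response now carries one detokenized string per input sentence instead of only the first one. A hypothetical client call might look like the following; the host, port, and route are assumptions, since the Flask-RESTful route registration is not part of this diff.

import requests

resp = requests.put(
    "http://localhost:5000/generate",   # host, port, and path are placeholders
    json={
        "sentences": ["Megatron-LM is", "The theory of relativity"],
        "max_len": 64,                   # optional; values >= args.seq_length are ignored
    },
)
print(resp.json()["sentences"])          # one generated string per input sentence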
megatron/text_generation_utils.py (view file @ a9a3ef50)
...
@@ -40,7 +40,8 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()

     # Move to GPU.
-    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
+    tokens = context_tokens.contiguous().cuda()

     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
...
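The get_batch change drops the reshape to (args.micro_batch_size, -1): context_tokens already arrives shaped (batch, seq_len) from tokenize_batch, and presumably the reshape misbehaved whenever the request batch size differed from the training-time micro-batch size. A toy CPU-only illustration with made-up shapes:

import torch

context_tokens = torch.randint(0, 50000, (3, 8))   # a request with 3 prompts

# Old behaviour, assuming args.micro_batch_size == 4: the 24 elements get
# silently reshaped to (4, 6), scrambling prompt boundaries (and raising an
# error whenever the element count is not divisible by the micro-batch size).
#   tokens = context_tokens.view(4, -1).contiguous().cuda()

# New behaviour keeps the caller's batch shape (.cuda() omitted in this sketch).
tokens = context_tokens.contiguous()
print(tokens.shape)   # torch.Size([3, 8])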
@@ -464,7 +465,6 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):

     args = get_args()
     tokenizer = get_tokenizer()
...