Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
811183f0
"vscode:/vscode.git/clone" did not exist on "2dd7d0c533deecd9e4fea682f1d13fd8e7e9b8a2"
Commit
811183f0
authored
Aug 11, 2021
by
rprenger
Browse files
Got it working on the full big model
parent
ddd36145
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
megatron/api_server.py
megatron/api_server.py
+3
-2
megatron/text_generation_utils.py
megatron/text_generation_utils.py
+8
-6
No files found.
megatron/api_server.py
View file @
811183f0
...
@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
...
@@ -54,12 +54,13 @@ class MegatronGenerate(Resource):
return
"all_probs must be a boolean value"
return
"all_probs must be a boolean value"
MegatronGenerate
.
send_do_generate
()
# Tell other ranks we're doing generate
MegatronGenerate
.
send_do_generate
()
# Tell other ranks we're doing generate
resp_sentences
,
resp_sentences_seg
,
output_logits
,
full_logits
=
generate
(
self
.
model
,
sentences
,
max_len
,
all_probs
)
resp_sentences
,
resp_sentences_seg
,
output_logits
,
full_logits
,
tokens
=
generate
(
self
.
model
,
sentences
,
max_len
,
all_probs
)
if
all_probs
:
if
all_probs
:
return
jsonify
({
"sentences"
:
resp_sentences
,
return
jsonify
({
"sentences"
:
resp_sentences
,
"segments"
:
resp_sentences_seg
,
"segments"
:
resp_sentences_seg
,
"logits"
:
output_logits
,
"logits"
:
output_logits
,
"all_logits"
:
full_logits
})
"all_logits"
:
full_logits
,
"tokens"
:
tokens
})
return
jsonify
({
"sentences"
:
resp_sentences
,
return
jsonify
({
"sentences"
:
resp_sentences
,
"segments"
:
resp_sentences_seg
,
"segments"
:
resp_sentences_seg
,
...
...
megatron/text_generation_utils.py
View file @
811183f0
...
@@ -121,7 +121,7 @@ def receive_generate_info():
...
@@ -121,7 +121,7 @@ def receive_generate_info():
"""
"""
Needs to be synced up with send_generate_info
Needs to be synced up with send_generate_info
"""
"""
input_info_tensor
=
torch
.
empty
(
3
,
dtype
=
torch
.
int64
,
device
=
torch
.
device
(
"cuda"
))
input_info_tensor
=
torch
.
empty
(
4
,
dtype
=
torch
.
int64
,
device
=
torch
.
device
(
"cuda"
))
torch
.
distributed
.
broadcast
(
input_info_tensor
,
0
)
torch
.
distributed
.
broadcast
(
input_info_tensor
,
0
)
batch_size
=
input_info_tensor
[
0
].
item
()
batch_size
=
input_info_tensor
[
0
].
item
()
seq_len
=
input_info_tensor
[
1
].
item
()
seq_len
=
input_info_tensor
[
1
].
item
()
...
@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
...
@@ -166,9 +166,10 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
torch
.
distributed
.
broadcast
(
output_logits
,
src
,
group
)
torch
.
distributed
.
broadcast
(
output_logits
,
src
,
group
)
if
all_probs
:
if
all_probs
:
args
=
get_args
()
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
group
=
mpu
.
get_embedding_group
()
group
=
mpu
.
get_embedding_group
()
full_logits
=
torch
.
empty
(
tokens
.
size
(
0
),
context_length
,
args
.
padded_vocab_size
()
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cuda"
))
full_logits
=
torch
.
empty
(
tokens
.
size
(
0
),
context_length
,
args
.
padded_vocab_size
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cuda"
))
torch
.
distributed
.
broadcast
(
full_logits
,
src
,
group
)
torch
.
distributed
.
broadcast
(
full_logits
,
src
,
group
)
if
tokens
is
not
None
:
if
tokens
is
not
None
:
...
@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
...
@@ -193,8 +194,9 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
tokenizer
=
get_tokenizer
()
tokenizer
=
get_tokenizer
()
resp_sentences
=
[]
resp_sentences
=
[]
resp_sentences_seg
=
[]
resp_sentences_seg
=
[]
for
i
in
range
(
decode_tokens
.
size
(
0
)):
decode_token
=
decode_tokens
[
i
,:].
cpu
().
numpy
().
tolist
()
decode_tokens
=
decode_tokens
.
cpu
().
numpy
().
tolist
()
for
decode_token
in
decode_tokens
:
resp_sentences
.
append
(
tokenizer
.
detokenize
(
decode_token
))
resp_sentences
.
append
(
tokenizer
.
detokenize
(
decode_token
))
words
=
[]
words
=
[]
for
token
in
decode_token
:
for
token
in
decode_token
:
...
@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
...
@@ -208,8 +210,8 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
full_logits
=
full_logits
.
cpu
().
numpy
().
tolist
()
full_logits
=
full_logits
.
cpu
().
numpy
().
tolist
()
end
=
time
.
time
()
end
=
time
.
time
()
print
(
str
(
b
)
+
","
+
str
(
c
)
+
","
+
str
(
decode_tokens
.
size
(
1
))
+
","
+
str
(
end
-
start
),
flush
=
True
)
print
(
str
(
b
)
+
","
+
str
(
c
)
+
","
+
str
(
len
(
decode_tokens
[
0
]
))
+
","
+
str
(
end
-
start
),
flush
=
True
)
return
resp_sentences
,
resp_sentences_seg
,
output_logits
,
full_logits
return
resp_sentences
,
resp_sentences_seg
,
output_logits
,
full_logits
,
decode_tokens
def
switch
(
val1
,
val2
,
boolean
):
def
switch
(
val1
,
val2
,
boolean
):
boolean
=
boolean
.
type_as
(
val1
)
boolean
=
boolean
.
type_as
(
val1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment