Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
e0bf5199
Commit
e0bf5199
authored
Jul 29, 2021
by
rprenger
Browse files
Outputting log probabilities
parent
279d8320
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
12 deletions
+51
-12
megatron/api_server.py
megatron/api_server.py
+4
-3
megatron/text_generation_utils.py
megatron/text_generation_utils.py
+47
-9
No files found.
megatron/api_server.py
View file @
e0bf5199
...
@@ -48,9 +48,10 @@ class MegatronGenerate(Resource):
...
@@ -48,9 +48,10 @@ class MegatronGenerate(Resource):
return
"max_len must be an integer greater than 0"
return
"max_len must be an integer greater than 0"
MegatronGenerate
.
send_do_generate
()
# Tell other ranks we're doing generate
MegatronGenerate
.
send_do_generate
()
# Tell other ranks we're doing generate
resp_sentences
=
generate
(
self
.
model
,
sentences
,
max_len
)
resp_sentences
,
resp_sentences_seg
,
output_logits
=
generate
(
self
.
model
,
sentences
,
max_len
)
return
jsonify
({
"sentences"
:
resp_sentences
})
return
jsonify
({
"sentences"
:
resp_sentences
,
"segments"
:
resp_sentences_seg
,
"logits"
:
output_logits
})
def
index
():
def
index
():
return
current_app
.
send_static_file
(
'index.html'
)
return
current_app
.
send_static_file
(
'index.html'
)
...
...
megatron/text_generation_utils.py
View file @
e0bf5199
...
@@ -144,11 +144,22 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
...
@@ -144,11 +144,22 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
context_length_tensor
,
context_length_tensor
,
attention_mask
,
position_ids
,
attention_mask
,
position_ids
,
max_len
)
max_len
)
for
tokens
,
lengths
in
batch_token_iterator
:
for
tokens
,
lengths
,
output_logits
in
batch_token_iterator
:
context_length
+=
1
context_length
+=
1
if
mpu
.
is_pipeline_last_stage
():
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
group
=
mpu
.
get_embedding_group
()
torch
.
distributed
.
broadcast
(
output_logits
,
src
,
group
)
else
:
if
mpu
.
is_pipeline_first_stage
():
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
group
=
mpu
.
get_embedding_group
()
output_logits
=
torch
.
empty
(
tokens
.
size
(
0
),
context_length
-
1
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
"cuda"
))
torch
.
distributed
.
broadcast
(
output_logits
,
src
,
group
)
if
tokens
is
not
None
:
if
tokens
is
not
None
:
return
tokens
[:,
:
context_length
]
return
tokens
[:,
:
context_length
]
,
output_logits
def
generate
(
model
,
sentences
=
None
,
max_len
=
0
):
def
generate
(
model
,
sentences
=
None
,
max_len
=
0
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
...
@@ -160,18 +171,29 @@ def generate(model, sentences=None, max_len=0):
...
@@ -160,18 +171,29 @@ def generate(model, sentences=None, max_len=0):
else
:
else
:
context_length_tensor
,
context_tokens_tensor
,
max_len
=
receive_generate_info
()
context_length_tensor
,
context_tokens_tensor
,
max_len
=
receive_generate_info
()
decode_tokens
=
synced_generate
(
model
,
context_tokens_tensor
,
context_length_tensor
,
max_len
)
output
=
synced_generate
(
model
,
context_tokens_tensor
,
context_length_tensor
,
max_len
)
if
output
is
not
None
:
decode_tokens
,
output_logits
=
output
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
args
=
get_args
()
args
=
get_args
()
tokenizer
=
get_tokenizer
()
tokenizer
=
get_tokenizer
()
resp_sentences
=
[]
resp_sentences
=
[]
resp_sentences_seg
=
[]
for
i
in
range
(
decode_tokens
.
size
(
0
)):
for
i
in
range
(
decode_tokens
.
size
(
0
)):
decode_token
=
decode_tokens
[
i
,:].
cpu
().
numpy
().
tolist
()
decode_token
=
decode_tokens
[
i
,:].
cpu
().
numpy
().
tolist
()
resp_sentences
.
append
(
tokenizer
.
detokenize
(
decode_token
))
resp_sentences
.
append
(
tokenizer
.
detokenize
(
decode_token
))
words
=
[]
for
token
in
decode_token
:
word
=
tokenizer
.
tokenizer
.
decoder
[
token
]
word
=
bytearray
([
tokenizer
.
tokenizer
.
byte_decoder
[
c
]
for
c
in
word
]).
decode
(
'utf-8'
,
errors
=
'replace'
)
words
.
append
(
word
)
resp_sentences_seg
.
append
(
words
)
output_logits
=
output_logits
.
cpu
().
numpy
().
tolist
()
end
=
time
.
time
()
end
=
time
.
time
()
print
(
str
(
b
)
+
","
+
str
(
c
)
+
","
+
str
(
decode_tokens
.
size
(
1
))
+
","
+
str
(
end
-
start
),
flush
=
True
)
print
(
str
(
b
)
+
","
+
str
(
c
)
+
","
+
str
(
decode_tokens
.
size
(
1
))
+
","
+
str
(
end
-
start
),
flush
=
True
)
return
resp_sentences
return
resp_sentences
,
resp_sentences_seg
,
output_logits
def
switch
(
val1
,
val2
,
boolean
):
def
switch
(
val1
,
val2
,
boolean
):
boolean
=
boolean
.
type_as
(
val1
)
boolean
=
boolean
.
type_as
(
val1
)
...
@@ -236,6 +258,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
...
@@ -236,6 +258,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
batch_size
=
context_tokens
.
size
(
0
)
batch_size
=
context_tokens
.
size
(
0
)
is_done
=
torch
.
zeros
([
batch_size
]).
byte
().
cuda
()
is_done
=
torch
.
zeros
([
batch_size
]).
byte
().
cuda
()
tokens
=
context_tokens
tokens
=
context_tokens
output_logits
=
None
if
maxlen
is
None
:
if
maxlen
is
None
:
maxlen
=
args
.
seq_length
-
1
maxlen
=
args
.
seq_length
-
1
...
@@ -261,6 +285,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
...
@@ -261,6 +285,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
if
type_ids
is
not
None
:
if
type_ids
is
not
None
:
types2use
=
type_ids
[:,
context_length
-
1
].
view
(
types2use
=
type_ids
[:,
context_length
-
1
].
view
(
batch_size
,
-
1
)
batch_size
,
-
1
)
output
,
layer_past
=
forward_step
(
model
,
tokens2use
,
output
,
layer_past
=
forward_step
(
model
,
tokens2use
,
positions2use
,
positions2use
,
attention_mask
,
attention_mask
,
...
@@ -288,6 +313,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
...
@@ -288,6 +313,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
new_tokens
=
switch
(
new_tokens
=
switch
(
tokens
[:,
context_length
].
view
(
-
1
),
prev
,
started
)
tokens
[:,
context_length
].
view
(
-
1
),
prev
,
started
)
tokens
[:,
context_length
]
=
new_tokens
tokens
[:,
context_length
]
=
new_tokens
if
output_logits
is
None
:
output_context
=
F
.
log_softmax
(
output
[:,
:
context_length
,
:],
2
)
indices
=
torch
.
unsqueeze
(
tokens
[:,
:
context_length
],
2
)
output_logits
=
torch
.
gather
(
output_context
,
2
,
indices
).
squeeze
(
2
)
else
:
indices
=
torch
.
unsqueeze
(
new_tokens
,
1
).
unsqueeze
(
2
)
new_output_logits
=
torch
.
gather
(
F
.
log_softmax
(
output
,
2
),
2
,
indices
).
squeeze
(
2
)
# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits
=
torch
.
cat
([
output_logits
,
new_output_logits
],
1
)
#output_logits = torch.cat([output_logits, output[:,context_length,new_tokens]], 1)
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
group
=
mpu
.
get_embedding_group
()
group
=
mpu
.
get_embedding_group
()
torch
.
distributed
.
broadcast
(
new_tokens
,
src
,
group
)
torch
.
distributed
.
broadcast
(
new_tokens
,
src
,
group
)
...
@@ -301,7 +339,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
...
@@ -301,7 +339,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
group
=
mpu
.
get_pipeline_model_parallel_group
()
group
=
mpu
.
get_pipeline_model_parallel_group
()
torch
.
distributed
.
broadcast
(
done
,
src
,
group
)
torch
.
distributed
.
broadcast
(
done
,
src
,
group
)
yield
tokens
,
lengths
yield
tokens
,
lengths
,
output_logits
else
:
else
:
if
mpu
.
is_pipeline_first_stage
():
if
mpu
.
is_pipeline_first_stage
():
...
@@ -310,9 +348,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
...
@@ -310,9 +348,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
new_tokens
=
torch
.
empty_like
(
tokens
[:,
context_length
])
new_tokens
=
torch
.
empty_like
(
tokens
[:,
context_length
])
torch
.
distributed
.
broadcast
(
new_tokens
,
src
,
group
)
torch
.
distributed
.
broadcast
(
new_tokens
,
src
,
group
)
tokens
[:,
context_length
]
=
new_tokens
tokens
[:,
context_length
]
=
new_tokens
yield
tokens
,
None
yield
tokens
,
None
,
None
else
:
else
:
yield
None
,
None
yield
None
,
None
,
None
done
=
torch
.
cuda
.
ByteTensor
([
0
])
done
=
torch
.
cuda
.
ByteTensor
([
0
])
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
src
=
mpu
.
get_pipeline_model_parallel_last_rank
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment