OpenDAS / Megatron-LM · Commits

Commit 9cc286ba
Getting tokens_to_generate=0 to work
Authored Oct 15, 2021 by rprenger
Parent 21d2b0fe
Showing 3 changed files with 94 additions and 57 deletions (+94 -57):
  megatron/text_generation/api.py          +21 -23
  megatron/text_generation/generation.py   +68 -30
  megatron/text_generation_server.py        +5  -4
megatron/text_generation/api.py
@@ -20,7 +20,9 @@ import torch
 from megatron import mpu
 from .communication import broadcast_float_list
-from .generation import generate_tokens_probs_and_return_on_first_stage
+from .generation import (
+    generate_tokens_probs_and_return_on_first_stage,
+    score_and_return_on_first_stage)
 from .tokenization import (
     tokenize_prompts,
     detokenize_generations)
@@ -31,7 +33,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              return_all_log_probs=False,
                               greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
@@ -43,12 +44,11 @@ def generate_and_post_process(model,
     move to cpu and convert to list."""

     # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
         model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
@@ -59,17 +59,16 @@ def generate_and_post_process(model,
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         tokens, prompts_plus_generations, prompts_plus_generations_segments = \
             detokenize_generations(tokens, lengths, True)

         if return_output_log_probs:
             output_log_probs = output_log_probs.cpu().numpy().tolist()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs.cpu().numpy().tolist()
+            for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
+                output_log_probs[i] = prob[:len(seg) - 1]

         return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

     return None
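Note on the post-processing change above: each per-sequence list of log-probs is now trimmed to the length of its detokenized segment minus one, since the first token of a sequence has no preceding position from which it can be scored. A minimal standalone sketch of that trimming, using made-up toy values rather than real model output:

    # Toy illustration of the per-sequence trimming (hypothetical values).
    # Each segment holds the detokenized pieces of one prompt+generation;
    # log-probs beyond len(segment) - 1 are padding and get dropped.
    output_log_probs = [[-0.1, -2.3, -0.7, -4.2, -9.9],
                        [-1.5, -0.2, -3.1, -8.8, -7.7]]
    segments = [["The", "cat", "sat", "."],
                ["Hello", "world", "!"]]

    for i, (prob, seg) in enumerate(zip(output_log_probs, segments)):
        output_log_probs[i] = prob[:len(seg) - 1]

    print(output_log_probs)  # [[-0.1, -2.3, -0.7], [-1.5, -0.2]]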
@@ -79,7 +78,6 @@ def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
-             return_all_log_probs=False,
              greedy_sampling=False,
              top_k_sampling=0,
              top_p_sampling=0.0,
@@ -93,25 +91,23 @@ def generate(model,
         discard tokens in the tokens tensor that are after the
         corresponding length.
         output_log_probs: log probs of the tokens.
-        all_log_probs: full log probs for all of tokens.
     """

     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
-              return_output_log_probs, return_all_log_probs,
+              return_output_log_probs,
               greedy_sampling, top_k_sampling, top_p_sampling,
               temperature, add_BOS, use_eod_token_for_early_termination, just_score]
-    values_float_tensor = broadcast_float_list(10, float_list=values)
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
-    return_all_log_probs = bool(values_float_tensor[2].item())
-    greedy_sampling = bool(values_float_tensor[3].item())
-    top_k_sampling = int(values_float_tensor[4].item())
-    top_p_sampling = values_float_tensor[5].item()
-    temperature = values_float_tensor[6].item()
-    add_BOS = bool(values_float_tensor[7].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
-    just_score = bool(values_float_tensor[9].item())
+    greedy_sampling = bool(values_float_tensor[2].item())
+    top_k_sampling = int(values_float_tensor[3].item())
+    top_p_sampling = values_float_tensor[4].item()
+    temperature = values_float_tensor[5].item()
+    add_BOS = bool(values_float_tensor[6].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
+    just_score = bool(values_float_tensor[8].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
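The index reshuffling above follows from how generate() shares its arguments across ranks: every sampling parameter is coerced to a float, packed into one flat tensor, and broadcast from rank 0 via broadcast_float_list, so dropping return_all_log_probs shifts every later slot down by one (hence 10 entries becoming 9). A single-process sketch of the pack/unpack round trip; the cross-rank broadcast itself is omitted and the values below are hypothetical:

    import torch

    # Pack in the same order as the new `values` list in generate(); everything
    # is coerced to float so one tensor carries ints, bools, and floats alike.
    values = [0,      # tokens_to_generate (0 means "just score the prompt")
              True,   # return_output_log_probs
              False,  # greedy_sampling
              0,      # top_k_sampling
              0.9,    # top_p_sampling
              1.0,    # temperature
              False,  # add_BOS
              True,   # use_eod_token_for_early_termination
              True]   # just_score
    values_float_tensor = torch.tensor(values, dtype=torch.float32)
    # In Megatron, broadcast_float_list(9, float_list=values) would now send
    # this tensor from rank 0 to every other rank.

    # Unpack with the new index layout, casting back to each parameter's type.
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_p_sampling = values_float_tensor[4].item()
    just_score = bool(values_float_tensor[8].item())
    assert tokens_to_generate == 0 and just_score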
@@ -121,13 +117,15 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

+    if just_score:
+        return score_and_return_on_first_stage(
+            model, context_tokens_tensor, context_length_tensor)
+
     # Main inference function.
     # Note that the outputs are available on the first stage.
     return generate_tokens_probs_and_return_on_first_stage(
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
         temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
-        just_score=just_score)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
megatron/text_generation/generation.py
@@ -27,15 +27,76 @@ from .communication import (
 from .forward_step import ForwardStep
 from .sampling import sample


+def score_and_return_on_first_stage(model, tokens, lengths):
+    """Function for just scoring.
+    Arguments:
+        model: no interleaving is supported.
+        tokens: prompt tokens extended to be of size [b, max_prompt_length]
+        lengths: original prompt length, size: [b]
+    Note: Outside of model, other parameters only need to be available on
+        rank 0.
+    Outputs:
+        output_log_probs: log probability of the selected tokens. size: [b, s]
+    """
+
+    args = get_args()
+
+    batch_size = tokens.size(0)
+    max_prompt_length = lengths.max().item()
+    assert max_prompt_length == tokens.size(1)
+    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
+
+    # forward step.
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)
+
+    # ===================
+    # Pre-allocate memory
+    # ===================
+
+    # Log probability of the sequence (prompt + generated tokens).
+    output_log_probs = None
+    output_log_probs_size = (batch_size, max_sequence_length - 1)
+
+    if mpu.is_pipeline_last_stage():
+        output_log_probs = torch.empty(output_log_probs_size,
+                                       dtype=torch.float32,
+                                       device=torch.cuda.current_device())
+
+    # =============
+    # Run infernece
+    # =============
+    with torch.no_grad():
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(
+            tokens)
+
+        # logits will be meanigful only in the last pipeline stage.
+        logits = forward_step(tokens, position_ids, attention_mask)
+
+        if mpu.is_pipeline_last_stage():
+            # Always the last stage should have an output.
+            assert logits is not None
+            log_probs = F.log_softmax(logits, dim=2)
+
+            # Pick the tokens that we need to get the log
+            # probabilities for. Note that next input token is
+            # the token which we selected in the current logits,
+            # so shift by 1.
+            indices = torch.unsqueeze(tokens[:, 1:], 2)
+            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
+
+    # ======================================
+    # Broadcast to the first pipeline stage.
+    # ======================================
+    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
+        output_log_probs_size, torch.float32, output_log_probs)

+    return tokens, lengths, output_log_probs
+
+
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True,
-        just_score=False):
+        use_eod_token_for_early_termination=True):
     """Main token generation function.
     Arguments:
         model: no interleaving is supported.
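The heart of the new score_and_return_on_first_stage is the log-probability gather: take log_softmax of the logits, then index them with the input tokens shifted by one, because the logits at position t score the token that appears at position t+1. A toy, CPU-only sketch of that indexing, with random logits standing in for a real forward pass:

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 2, 5, 11              # toy sizes
    tokens = torch.randint(vocab_size, (batch_size, seq_len))
    logits = torch.randn(batch_size, seq_len, vocab_size)   # stand-in for forward_step(...)

    log_probs = F.log_softmax(logits, dim=2)
    indices = torch.unsqueeze(tokens[:, 1:], 2)              # [b, s-1, 1]
    output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)  # [b, s-1]

    # Same thing spelled out as a loop, to make the shift-by-one explicit.
    for b in range(batch_size):
        for t in range(seq_len - 1):
            assert torch.isclose(output_log_probs[b, t],
                                 log_probs[b, t, tokens[b, t + 1]])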
@@ -44,9 +105,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
             after logits are modifed for sampling.
-        return_all_log_probs: flag to calculate the log probability of across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modifed for sampling.
         greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
             Note that these three paramters are exclusive meaning that:
             if greedy = true then we should have top-k=top-p=0.
@@ -63,8 +121,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         generated_sequence_lengths: total length (including prompt) of
             the generated sequence. size: [b]
         output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
     """

     args = get_args()
@@ -93,9 +149,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     output_log_probs = None
     output_log_probs_size = (batch_size, max_sequence_length - 1)
     # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length - 1,
-                          args.padded_vocab_size)
     # Lengths of generated seuquence including including prompts.
     generated_sequence_lengths = None
     if mpu.is_pipeline_last_stage():
@@ -103,10 +157,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs = torch.empty(output_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
         generated_sequence_lengths = torch.ones(
                 batch_size, dtype=torch.int64,
                 device=torch.cuda.current_device()) * max_sequence_length
@@ -159,12 +209,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
-            if return_output_log_probs or return_all_log_probs:
+            if return_output_log_probs:
                 log_probs = F.log_softmax(logits, dim=2)
-                if return_all_log_probs:
-                    all_log_probs[:, prev_context_length:context_length, :] = \
-                        log_probs
                 if return_output_log_probs:
                     # Pick the tokens that we need to get the log
                     # probabilities for. Note that next input token is
@@ -210,8 +256,6 @@ def generate_tokens_probs_and_return_on_first_stage(
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
             output_log_probs = output_log_probs[:, :context_length].contiguous()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

     # ======================================
     # Broadcast to the first pipeline stage.
@@ -223,14 +267,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs_size = (batch_size, context_length)
         output_log_probs = broadcast_from_last_to_first_pipeline_stage(
             output_log_probs_size, torch.float32, output_log_probs)
-    if return_all_log_probs:
-        all_log_probs_size = (batch_size, context_length,
-                              args.padded_vocab_size)
-        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-            all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs
+    return tokens, generated_sequence_lengths, output_log_probs
megatron/text_generation_server.py
@@ -69,6 +69,8 @@ class MegatronGenerate(Resource):
             logprobs = request.get_json()["logprobs"]
             if not isinstance(logprobs, bool):
                 return "logprobs must be a boolean value"
+        if just_score and not logprobs:
+            return "tokens_to_generate=0 implies logprobs=True"

         temperature = 1.0
         if "temperature" in request.get_json():
@@ -83,7 +85,7 @@ class MegatronGenerate(Resource):
             top_k = request.get_json()["top_k"]
             if not (type(top_k) == int):
                 return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
-            if not (0 < top_k <= 1000):
+            if not (0 <= top_k <= 1000):
                 return "top_k must be equal to or greater than 0 and less than or equal to 1000"

         top_p = 0.0
@@ -93,7 +95,7 @@ class MegatronGenerate(Resource):
                 return "top_p must be a positive float less than or equal to 1.0"
             if top_p > 0.0 and top_k > 0.0:
                 return "cannot set both top-k and top-p samplings."
-            if not (0 < top_p <= 1.0):
+            if not (0 <= top_p <= 1.0):
                 return "top_p must be less than or equal to 1.0"

         add_BOS = False
@@ -104,13 +106,12 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                 generate_and_post_process(
                     self.model,
                     prompts=prompts,
                     tokens_to_generate=tokens_to_generate,
                     return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                     greedy_sampling=args.greedy,
                     top_k_sampling=top_k,
                     top_p_sampling=top_p,
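With the guard added above, a scoring-only call has to send tokens_to_generate=0 together with logprobs=true. A hedged client-side example using Python requests; the /api path, the PUT verb, and port 5000 are assumptions about how this Flask server is commonly mounted and may not match this exact revision:

    import requests

    URL = "http://localhost:5000/api"  # hypothetical endpoint; adjust to your deployment

    payload = {
        "prompts": ["Megatron-LM scales transformer training across many GPUs."],
        "tokens_to_generate": 0,   # score the prompt only, generate nothing
        "logprobs": True,          # required: tokens_to_generate=0 implies logprobs=True
    }

    resp = requests.put(URL, json=payload)
    print(resp.status_code)
    print(resp.json())  # expected to contain per-token log probabilities for the prompt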