wuxk1 / Megatron-LM · Commit c6e7c7fd
Authored Oct 15, 2021 by mshoeybi
Parent: 8d405805

    removed return all probs

Showing 4 changed files with 49 additions and 63 deletions:

  megatron/text_generation/api.py            +10  -19
  megatron/text_generation/communication.py  +35  -13
  megatron/text_generation/generation.py      +3  -29
  megatron/text_generation_server.py          +1   -2
megatron/text_generation/api.py  (+10 -19)

@@ -31,7 +31,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              return_all_log_probs=False,
                               greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
@@ -42,12 +41,11 @@ def generate_and_post_process(model,
     move to cpu and convert to list."""

     # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
         model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
@@ -63,11 +61,9 @@ def generate_and_post_process(model,
         if return_output_log_probs:
             output_log_probs = output_log_probs.cpu().numpy().tolist()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs.cpu().numpy().tolist()

         return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

     return None
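
With return_all_log_probs gone, generate_and_post_process returns a 4-tuple instead of a 5-tuple on the first pipeline stage, and still returns None on every other rank. A minimal caller sketch under the new signature (the model object is assumed to be an already-initialized Megatron model; initialization code is omitted):

    # Hedged sketch: unpacking the new 4-tuple return value.
    from megatron.text_generation.api import generate_and_post_process

    result = generate_and_post_process(model,
                                       prompts=["Hello, world"],
                                       tokens_to_generate=16,
                                       return_output_log_probs=True)
    if result is not None:  # only the first pipeline stage gets real output
        generations, generation_segments, output_log_probs, tokens = result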
@@ -77,7 +73,6 @@ def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
-             return_all_log_probs=False,
              greedy_sampling=False,
              top_k_sampling=0,
              top_p_sampling=0.0,
@@ -90,24 +85,21 @@ def generate(model,
            discard tokens in the tokens tensor that are after the
            corresponding length.
        output_log_probs: log probs of the tokens.
-        all_log_probs: full log probs for all of tokens.
    """

    # Make sure input params are avaialble to all ranks.
    values = [tokens_to_generate,
              return_output_log_probs,
-              return_all_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+    values_float_tensor = broadcast_float_list(8, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
-    return_all_log_probs = bool(values_float_tensor[2].item())
-    greedy_sampling = bool(values_float_tensor[3].item())
-    top_k_sampling = int(values_float_tensor[4].item())
-    top_p_sampling = values_float_tensor[5].item()
-    temperature = values_float_tensor[6].item()
-    add_BOS = bool(values_float_tensor[7].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+    greedy_sampling = bool(values_float_tensor[2].item())
+    top_k_sampling = int(values_float_tensor[3].item())
+    top_p_sampling = values_float_tensor[4].item()
+    temperature = values_float_tensor[5].item()
+    add_BOS = bool(values_float_tensor[6].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcaseted to all ranks.
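
The values list and the positional unpacking have to stay in sync: dropping return_all_log_probs shrinks the broadcast from 9 to 8 floats and shifts every later index down by one. A standalone sketch of the same pack / broadcast / unpack-by-index pattern, using plain torch.distributed instead of Megatron's broadcast_float_list helper (the function name broadcast_scalars is illustrative only; torch.distributed must already be initialized and a CUDA device selected):

    import torch
    import torch.distributed as dist

    def broadcast_scalars(values, src=0):
        """Pack python scalars into a float32 CUDA tensor and broadcast it."""
        t = torch.tensor(values, dtype=torch.float32,
                         device=torch.cuda.current_device())
        dist.broadcast(t, src)
        return t

    # 8 parameters, mirroring the new broadcast_float_list(8, ...) call above.
    t = broadcast_scalars([64, True, False, 0, 0.9, 1.0, False, True])
    tokens_to_generate = int(t[0].item())
    return_output_log_probs = bool(t[1].item())
    greedy_sampling = bool(t[2].item())  # indices after this point all shift down by one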
@@ -122,7 +114,6 @@ def generate(model,
     return generate_tokens_probs_and_return_on_first_stage(
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
         temperature=temperature,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination)
megatron/text_generation/communication.py  (+35 -13)

@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):

+def _is_cuda(tensor):
+    """Check if a tensor is not none and is cuda."""
+    assert tensor is not None
+    assert tensor.is_cuda
+
+
+def _is_cuda_contiguous(tensor):
+    """Check if a tensor is not none, is cuda, and is contiguous."""
+    _is_cuda(tensor)
+    assert tensor.is_contiguous()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""

-    if mpu.is_pipeline_last_stage():
-        assert tensor is not None
-        assert tensor.is_cuda
-        assert tensor.is_contiguous()
+    is_last_stage = mpu.is_pipeline_last_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if mpu.is_pipeline_first_stage() and is_last_stage:
+        return tensor
+
+    if is_last_stage:
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,
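
The two new helpers fold the repeated assert-not-None / is_cuda / is_contiguous checks into single calls reused by the routines below. A small usage sketch (assumes the megatron package is importable and a CUDA device is available):

    import torch
    from megatron.text_generation.communication import _is_cuda, _is_cuda_contiguous

    x = torch.ones(2, 3, device='cuda')
    _is_cuda(x)                  # ok: not None and on a CUDA device
    _is_cuda_contiguous(x)       # ok: also contiguous
    _is_cuda_contiguous(x.t())   # AssertionError: the transpose is not contiguous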
@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):

 def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Broadcast tensor values from last stage into the first stage."""
-    # Only first and last stage pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return tensor
+    # Only first and last stage pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
         if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            assert tensor.is_contiguous()
+            _is_cuda_contiguous(tensor)
         else:
             tensor = torch.empty(size,
                                  dtype=dtype,
@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Copy tensor values from last stage into the first stage.
     Note that the input tensor is updated in place."""
-    # Only first and last stage pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If first stage and last state are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return
+    # Only first and last stage pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda(tensor)
         is_contiguous = tensor.is_contiguous()
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """

     if torch.distributed.get_rank() == rank:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,
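
The routines in this file share the same shape: the sending rank validates its tensor, the receiving ranks allocate an empty CUDA buffer of the agreed size and dtype, and torch.distributed.broadcast fills it in place. A minimal standalone sketch of that pattern (a simplified restatement of broadcast_tensor above, not Megatron's exact code):

    import torch
    import torch.distributed as dist

    def broadcast_tensor_sketch(size, dtype, tensor=None, rank=0):
        """Broadcast `tensor` from `rank`; other ranks receive into a fresh buffer."""
        if dist.get_rank() == rank:
            assert tensor is not None and tensor.is_cuda and tensor.is_contiguous()
        else:
            tensor = torch.empty(size, dtype=dtype,
                                 device=torch.cuda.current_device())
        dist.broadcast(tensor, rank)
        return tensor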
megatron/text_generation/generation.py  (+3 -29)

@@ -31,7 +31,6 @@ from .sampling import sample
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
         use_eod_token_for_early_termination=True):
@@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
             after logits are modifed for sampling.
-        return_all_log_probs: flag to calculate the log probability of across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modifed for sampling.
         greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
             Note that these three paramters are exclusive meaning that:
             if greedy = true then we should have top-k=top-p=0.
@@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         generated_sequence_lengths: total length (including prompt) of
             the generated sequence. size: [b]
         output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
     """

     args = get_args()
@@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage(
     # Log probability of the sequence (prompt + generated tokens).
     output_log_probs = None
     output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length - 1,
-                          args.padded_vocab_size)
     # Lengths of generated seuquence including including prompts.
     generated_sequence_lengths = None
     if mpu.is_pipeline_last_stage():
@@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs = torch.empty(output_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
         generated_sequence_lengths = torch.ones(
             batch_size, dtype=torch.int64,
             device=torch.cuda.current_device()) * max_sequence_length
@@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
-            if return_output_log_probs or return_all_log_probs:
+            if return_output_log_probs:
                 log_probs = F.log_softmax(logits, dim=2)
-                if return_all_log_probs:
-                    all_log_probs[:, prev_context_length:context_length, :] = \
-                        log_probs
                 if return_output_log_probs:
                     # Pick the tokens that we need to get the log
                     # probabilities for. Note that next input token is
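
With return_all_log_probs removed, log_softmax is computed only when per-token log probabilities are requested, and only the values of the tokens actually selected are kept; the full [b, s, vocab] all_log_probs buffer is no longer accumulated. A toy-shaped sketch of that gather step, separate from Megatron's generation loop:

    import torch
    import torch.nn.functional as F

    batch, steps, vocab = 2, 5, 11
    logits = torch.randn(batch, steps, vocab)
    chosen = torch.randint(vocab, (batch, steps))     # ids of the emitted tokens

    log_probs = F.log_softmax(logits, dim=2)          # [b, s, vocab]
    token_log_probs = log_probs.gather(               # [b, s]: one value per chosen token
        2, chosen.unsqueeze(2)).squeeze(2)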
@@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage(
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
             output_log_probs = output_log_probs[:, :context_length]
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

     # ======================================
     # Broadcast to the first pipeline stage.
@@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs_size = (batch_size, context_length)
         output_log_probs = broadcast_from_last_to_first_pipeline_stage(
             output_log_probs_size, torch.float32, output_log_probs)
-        if return_all_log_probs:
-            all_log_probs_size = (batch_size, context_length,
-                                  args.padded_vocab_size)
-            all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-                all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs
+    return tokens, generated_sequence_lengths, output_log_probs
megatron/text_generation_server.py  (+1 -2)

@@ -101,13 +101,12 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                 generate_and_post_process(
                     self.model,
                     prompts=prompts,
                     tokens_to_generate=tokens_to_generate,
                     return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                     greedy_sampling=args.greedy,
                     top_k_sampling=top_k,
                     top_p_sampling=top_p,
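
On the server side the only change is how the return tuple is unpacked; the REST interface itself is unchanged. A hedged client sketch (the endpoint path, HTTP method, and JSON field names are assumptions inferred from the surrounding server code, not shown in this diff; check run_text_generation_server.py / text_generation_server.py in your checkout):

    # Hypothetical client call; URL and field names are assumptions, not from this commit.
    import requests

    resp = requests.put("http://localhost:5000/api",
                        json={"prompts": ["Hello, world"],
                              "tokens_to_generate": 16,
                              "logprobs": True})
    print(resp.json())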