OpenDAS / Megatron-LM

Commit 6c40f892
Authored Sep 28, 2021 by mshoeybi
Commit message: working
Parent: 25f9c3f0

Showing 5 changed files with 243 additions and 43 deletions.
megatron/inference/api.py            +56   -0
megatron/inference/communication.py  +39   -8
megatron/inference/generation.py     +104  -33
megatron/inference/tokenization.py   +33   -0
megatron/text_generation_utils.py    +11   -2
megatron/inference/api.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""

import torch

from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts


def generate(model, prompts=None, tokens_to_generate=0,
             return_output_log_probs=False,
             return_all_log_probs=False,
             temperature=1.0):
    """TO DO ..."""

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature]
    values_float_tensor = broadcast_float_list(4, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    temperature = values_float_tensor[3].item()

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
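For context, a minimal sketch of how a caller might drive this API (not part of the commit). It assumes torch.distributed is already initialized and `model` is a Megatron GPT model set up for inference; every rank calls generate(), but only rank 0 needs the actual prompts, since the scalar arguments are broadcast inside the function.

# Hypothetical driver for megatron.inference.api.generate (illustrative only).
import torch
from megatron.inference.api import generate

def run_inference(model):
    prompts = None
    if torch.distributed.get_rank() == 0:
        # Only rank 0 needs the prompts; other ranks pass None.
        prompts = ['Deep learning is', 'The capital of France is']
    return generate(model,
                    prompts=prompts,
                    tokens_to_generate=32,
                    return_output_log_probs=True,
                    temperature=0.9)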
megatron/inference/communication.py
...

@@ -40,6 +40,33 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     return tensor


+def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
+    """Broadcast tensor values from last stage into the first stage."""
+    # Only the first and last pipeline stages need to be involved.
+    is_last_stage = mpu.is_pipeline_last_stage()
+    is_first_stage = mpu.is_pipeline_first_stage()
+    if is_last_stage or is_first_stage:
+        if is_last_stage:
+            assert tensor is not None
+            assert tensor.is_cuda
+            assert tensor.is_contiguous()
+        else:
+            tensor = torch.empty(size,
+                                 dtype=dtype,
+                                 device=torch.cuda.current_device())
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        # Broadcast from last stage into the first stage.
+        torch.distributed.broadcast(tensor, src, group)
+    else:
+        tensor = None
+
+    return tensor
+
+
 def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Copy tensor values from last stage into the first stage.
     Note that the input tensor is updated in place."""

...

@@ -48,20 +75,24 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
     if is_last_stage or is_first_stage:
+        assert tensor is not None
+        assert tensor.is_cuda
+        is_contiguous = tensor.is_contiguous()
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
-        if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            tensor_ = tensor.contiguous()
+        if is_contiguous:
+            tensor_ = tensor
         else:
-            tensor_ = torch.empty(size,
-                                  dtype=dtype,
-                                  device=torch.cuda.current_device())
+            if is_last_stage:
+                tensor_ = tensor.contiguous()
+            else:
+                tensor_ = torch.empty(size,
+                                      dtype=dtype,
+                                      device=torch.cuda.current_device())
         # Broadcast from last stage into the first stage.
         torch.distributed.broadcast(tensor_, src, group)
         # Update the first stage tensor
-        if is_first_stage:
+        if is_first_stage and not is_contiguous:
             tensor[...] = tensor_

...
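For reference, a minimal, self-contained sketch of the broadcast pattern these helpers implement: the source rank owns a contiguous tensor, the receiving rank allocates an empty buffer of the agreed size and dtype, and both call broadcast. This is not the commit's code; it uses the CPU gloo backend and the default process group in place of Megatron's mpu topology helpers.

# Standalone illustration of the last-to-first-stage broadcast pattern.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29511'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    size, dtype = (4,), torch.float32
    src = world_size - 1          # stand-in for the last pipeline stage
    if rank == src:
        # The source rank owns the data and passes a contiguous tensor.
        tensor = torch.arange(4, dtype=dtype)
    else:
        # The receiving rank allocates an empty buffer of the agreed shape.
        tensor = torch.empty(size, dtype=dtype)
    dist.broadcast(tensor, src)
    print(f'rank {rank}: {tensor.tolist()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)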
megatron/inference/generation.py
...

@@ -19,19 +19,44 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, get_tokenizer
-from megatron import mpu
+from megatron import get_args, get_tokenizer, mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
-    broadcast_from_last_pipeline_stage)
+    broadcast_from_last_pipeline_stage,
+    broadcast_from_last_to_first_pipeline_stage)
 from .forward_step import forward_step
 from .sampling import sample


-def generate_tokens(model, tokens, lengths, return_all_probs=False,
-                    temperature=1.0):
-    """Main token generation function."""
+def generate_tokens_probs_and_return_on_first_stage(
+        model, tokens, lengths,
+        return_output_log_probs=False,
+        return_all_log_probs=False,
+        temperature=1.0):
+    """Main token generation function.
+    Arguments:
+        model: XXX
+        tokens: prompt tokens extended to be of size [b, max-sequence-length]
+        lengths: original prompt length, size: [b]
+        return_output_log_probs: flag to calculate the log probability of
+            the generated tokens. Note that the log probability is the one
+            after logits are modified for sampling.
+        return_all_log_probs: flag to calculate the log probability across
+            all the tokens (vocab size). Note that the log probability is
+            the one after logits are modified for sampling.
+        temperature: sampling temperature.
+    Note: Outside of model, other parameters only need to be available on
+        rank 0.
+    Outputs: Note that the sizes are adjusted to a lower value than
+        max-sequence-length if generation is terminated early.
+        tokens: prompt and generated tokens. size: [b, :]
+        generated_sequence_lengths: total length (including prompt) of
+            the generated sequence. size: [b]
+        output_log_probs: log probability of the selected tokens. size: [b, s]
+        all_log_probs: log probability of all the tokens.
+            size: [b, s, vocab-size]
+    """

     args = get_args()
     tokenizer = get_tokenizer()

...

@@ -52,18 +77,35 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
     # ===================
     # Pre-allocate memory
     # ===================

-    # Log probability of the sequence (prompt + generated tokens)
-    output_log_probs = torch.empty(batch_size, max_sequence_length - 1,
-                                   dtype=torch.float32,
-                                   device=torch.cuda.current_device())
+    # Log probability of the sequence (prompt + generated tokens).
+    output_log_probs = None
+    output_log_probs_size = (batch_size, max_sequence_length - 1)
+    # Log probability of all tokens for the sequence.
+    all_log_probs = None
+    all_log_probs_size = (batch_size, max_sequence_length - 1,
+                          args.padded_vocab_size)
     # Lengths of generated sequence including prompts.
-    generated_sequence_lengths = torch.ones(
-        batch_size, dtype=torch.int64,
-        device=torch.cuda.current_device()) * max_sequence_length
+    generated_sequence_lengths = None
+    if mpu.is_pipeline_last_stage():
+        if return_output_log_probs:
+            output_log_probs = torch.empty(output_log_probs_size,
+                                           dtype=torch.float32,
+                                           device=torch.cuda.current_device())
+        if return_all_log_probs:
+            all_log_probs = torch.empty(all_log_probs_size,
+                                        dtype=torch.float32,
+                                        device=torch.cuda.current_device())
+        generated_sequence_lengths = torch.ones(
+            batch_size, dtype=torch.int64,
+            device=torch.cuda.current_device()) * max_sequence_length
     # Whether we have reached a termination id.
     is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                      device=torch.cuda.current_device())

     # =============
     # Run inference
     # =============
     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)

...

@@ -114,15 +156,25 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
-            log_probs = F.log_softmax(logits, dim=2)
-            # Pick the tokens that we need to get the log probabilities for.
-            # Note that next input token is the token which we selected in
-            # the current logits, so shift by 1.
-            indices = torch.unsqueeze(
-                tokens[:, (prev_context_length + 1):(context_length + 1)],
-                2)
-            output_log_probs[:, prev_context_length:context_length] = \
-                torch.gather(log_probs, 2, indices).squeeze(2)
+            if return_output_log_probs or return_all_log_probs:
+                log_probs = F.log_softmax(logits, dim=2)
+                if return_all_log_probs:
+                    all_log_probs[:, prev_context_length:context_length, :] = \
+                        log_probs
+                if return_output_log_probs:
+                    # Pick the tokens that we need to get the log
+                    # probabilities for. Note that next input token is
+                    # the token which we selected in the current logits,
+                    # so shift by 1.
+                    indices = torch.unsqueeze(
+                        tokens[:,
+                               (prev_context_length + 1):(context_length + 1)],
+                        2)
+                    output_log_probs[:, prev_context_length:context_length] = \
+                        torch.gather(log_probs, 2, indices).squeeze(2)

         # Update the tokens on the first stage so the next input to
         # the network is correct.

...

@@ -147,17 +199,36 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
         if done:
             break

-    if mpu.is_pipeline_last_stage():
-        if return_all_probs:
-            full_logits = None
-        return tokens, generated_sequence_lengths, output_log_probs, \
-            full_logits, context_length + 1
-        return tokens, generated_sequence_lengths, output_log_probs, \
-            None, context_length + 1
-    if mpu.is_pipeline_first_stage():
-        return tokens, None, None, None, \
-            context_length + 1
-    return None, None, None, None, \
-        context_length + 1
+    # ===================================================
+    # Update the lengths based on the max generated length.
+    # ===================================================
+    tokens = tokens[:, :(context_length + 1)]
+    if mpu.is_pipeline_last_stage():
+        if return_output_log_probs:
+            output_log_probs = output_log_probs[:, :context_length]
+        if return_all_log_probs:
+            all_log_probs = all_log_probs[:, :context_length, :]
+
+    # ======================================
+    # Broadcast to the first pipeline stage.
+    # ======================================
+    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
+        batch_size, torch.int64, generated_sequence_lengths)
+    if return_output_log_probs:
+        output_log_probs_size = (batch_size, context_length)
+        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
+            output_log_probs_size, torch.float32, output_log_probs)
+    if return_all_log_probs:
+        all_log_probs_size = (batch_size, context_length,
+                              args.padded_vocab_size)
+        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
+            all_log_probs_size, torch.float32, all_log_probs)
+
+    return tokens, generated_sequence_lengths, output_log_probs, \
+        all_log_probs


 def _build_attention_mask_and_position_ids(tokens):

...
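A small standalone illustration of the log-probability bookkeeping above, with toy shapes (not part of the commit): log_softmax over the chunk of logits, then gather the log probability of the token that was actually written one position later in the token buffer, which is why the indexing is shifted by 1.

# Illustration of the shift-by-1 gather used for output_log_probs.
import torch
import torch.nn.functional as F

batch_size, vocab_size = 2, 11
prev_context_length, context_length = 4, 7
num_steps = context_length - prev_context_length

# Fake logits for the current chunk and a fake token buffer.
logits = torch.randn(batch_size, num_steps, vocab_size)
tokens = torch.randint(vocab_size, (batch_size, context_length + 1))

log_probs = F.log_softmax(logits, dim=2)
# The token produced from position p is stored at position p + 1.
indices = torch.unsqueeze(
    tokens[:, (prev_context_length + 1):(context_length + 1)], 2)
selected = torch.gather(log_probs, 2, indices).squeeze(2)
print(selected.shape)  # torch.Size([2, 3]): one log prob per generated token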
megatron/inference/tokenization.py
...

@@ -23,6 +23,39 @@ from megatron import get_tokenizer
 from .communication import broadcast_int_list, broadcast_tensor


+def detokenize_generations(tokens_gpu_tensor,
+                           lengths_gpu_tensor,
+                           return_segments):
+    """Detokenize the generated tokens."""
+
+    tokenizer = get_tokenizer()
+
+    prompts_plus_generations = []
+    if return_segments:
+        prompts_plus_generations_segments = []
+
+    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
+    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
+    for sequence_tokens, length in zip(tokens, lengths):
+        sequence_tokens = sequence_tokens[:length]
+        prompts_plus_generations.append(
+            tokenizer.detokenize(sequence_tokens))
+        if return_segments:
+            words = []
+            for token in sequence_tokens:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray(
+                    [tokenizer.tokenizer.byte_decoder[c]
+                     for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            prompts_plus_generations_segments.append(words)
+
+    if return_segments:
+        return tokens, prompts_plus_generations, \
+            prompts_plus_generations_segments
+
+    return tokens, prompts_plus_generations
+
+
 def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
     """Tokenize prompts and make them available on all ranks."""

...
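A toy sketch of the trim-and-detokenize loop above (not part of the commit). The real code uses Megatron's tokenizer from get_tokenizer(); here a stand-in detokenize() that joins token ids is used purely to make the control flow runnable: each row is cut to its valid length before being decoded.

# Hypothetical stand-in for the detokenization loop; `detokenize` is a mock.
import torch

def detokenize(ids):
    # Stand-in for tokenizer.detokenize(ids).
    return ' '.join(str(i) for i in ids)

tokens_gpu_tensor = torch.tensor([[5, 7, 9, 0, 0],
                                  [3, 4, 6, 8, 2]])
lengths_gpu_tensor = torch.tensor([3, 5])

prompts_plus_generations = []
for sequence_tokens, length in zip(tokens_gpu_tensor.cpu().numpy().tolist(),
                                   lengths_gpu_tensor.cpu().numpy().tolist()):
    # Only the first `length` tokens of each row are valid output.
    prompts_plus_generations.append(detokenize(sequence_tokens[:length]))

print(prompts_plus_generations)   # ['5 7 9', '3 4 6 8 2']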
megatron/text_generation_utils.py
...

@@ -153,8 +153,12 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if mpu.is_pipeline_last_stage():
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
+        print('last rank output size {} {} | \n'.format(
+            output_logits.size(0), output_logits.size(1)))
         torch.distributed.broadcast(output_logits, src, group)

         if all_probs:
+            print('last rank full size {} {} {} | \n'.format(
+                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
             torch.distributed.broadcast(full_logits, src, group)

...

@@ -164,13 +168,18 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
         output_logits = torch.empty(tokens.size(0), context_length - 1,
                                     dtype=torch.float32,
                                     device=torch.device("cuda"))
+        print('first rank output size {} {} | \n'.format(
+            output_logits.size(0), output_logits.size(1)))
         torch.distributed.broadcast(output_logits, src, group)

         if all_probs:
             args = get_args()
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
-            full_logits = torch.empty(tokens.size(0), context_length,
-                                      args.padded_vocab_size,
-                                      dtype=torch.float32,
-                                      device=torch.device("cuda"))
+            full_logits = torch.empty(tokens.size(0), context_length - 1,
+                                      args.padded_vocab_size,
+                                      dtype=torch.float32,
+                                      device=torch.device("cuda"))
+            print('first rank full size {} {} {} | \n'.format(
+                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
             torch.distributed.broadcast(full_logits, src, group)

     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

...

@@ -204,7 +213,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
         output_logits = output_logits.cpu().numpy().tolist()

         if all_probs:
-            full_logits = full_logits.cpu().numpy().tolist()
+            full_logits = full_logits.cpu().numpy()  # .tolist()

         return resp_sentences, resp_sentences_seg, output_logits, \
             full_logits, decode_tokens

...
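A rough sketch of why dropping the .tolist() call on the full-vocabulary logits is attractive (illustrative shapes, not from the commit): the NumPy array keeps the values as one contiguous float32 buffer, whereas the nested Python list materializes a separate float object per element, which is far larger and slower to build on a typical 64-bit CPython.

# Size comparison for a [batch, sequence, vocab] logits tensor.
import sys
import torch

full_logits = torch.empty(2, 64, 50257, dtype=torch.float32)
array = full_logits.cpu().numpy()
print(array.nbytes)                    # ~25.7 MB of contiguous float32 data
nested = array.tolist()                # nested Python lists of float objects
print(sys.getsizeof(nested[0][0][0]))  # ~24 bytes for a single Python float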