OpenDAS / Megatron-LM

Commit 6c40f892, authored Sep 28, 2021 by mshoeybi
Parent: 25f9c3f0

    working

Showing 5 changed files with 243 additions and 43 deletions:

    megatron/inference/api.py            +56   -0
    megatron/inference/communication.py  +39   -8
    megatron/inference/generation.py     +104  -33
    megatron/inference/tokenization.py   +33   -0
    megatron/text_generation_utils.py    +11   -2
megatron/inference/api.py  (new file, mode 100644, +56 -0)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""


import torch

from .communication import broadcast_float_list
from .generation import generate_tokens_probs_and_return_on_first_stage
from .tokenization import tokenize_prompts


def generate(model, prompts=None, tokens_to_generate=0,
             return_output_log_probs=False,
             return_all_log_probs=False,
             temperature=1.0):
    """TO DO ..."""

    # Make sure the input parameters are available to all ranks.
    values = [tokens_to_generate, return_output_log_probs,
              return_all_log_probs, temperature]
    values_float_tensor = broadcast_float_list(4, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    # Temperature is the fourth broadcast value.
    temperature = values_float_tensor[3].item()

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcasted to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None
    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        return_all_log_probs=return_all_log_probs,
        temperature=temperature)
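For orientation, a minimal sketch of how this new entry point might be driven; it is not part of the commit. It assumes Megatron's distributed state (torch.distributed, mpu, the tokenizer, and a GPT model in eval mode) has already been initialized elsewhere, and the prompt strings, generation length, and temperature are illustrative only.

import torch

from megatron import mpu
from megatron.inference.api import generate


def run_inference(model):
    # Only rank 0 needs the raw prompt strings; generate() broadcasts the
    # tokenized batch and the scalar sampling parameters to all ranks.
    prompts = None
    if torch.distributed.get_rank() == 0:
        prompts = ["Deep learning is", "The capital of France is"]

    tokens, lengths, output_log_probs, all_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=32,
        return_output_log_probs=True,
        temperature=0.9)

    # Outputs are only populated on the first pipeline stage.
    if mpu.is_pipeline_first_stage():
        print(tokens.shape, lengths, output_log_probs.shape)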
megatron/inference/communication.py  (+39 -8)

@@ -40,6 +40,33 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     return tensor


+def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
+    """Broadcast tensor values from last stage into the first stage."""
+    # Only the first and last pipeline stages need to be involved.
+    is_last_stage = mpu.is_pipeline_last_stage()
+    is_first_stage = mpu.is_pipeline_first_stage()
+    if is_last_stage or is_first_stage:
+        if is_last_stage:
+            assert tensor is not None
+            assert tensor.is_cuda
+            assert tensor.is_contiguous()
+        else:
+            tensor = torch.empty(size,
+                                 dtype=dtype,
+                                 device=torch.cuda.current_device())
+        src = mpu.get_pipeline_model_parallel_last_rank()
+        group = mpu.get_embedding_group()
+        # Broadcast from last stage into the first stage.
+        torch.distributed.broadcast(tensor, src, group)
+    else:
+        tensor = None
+
+    return tensor
+
+
 def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Copy tensor values from last stage into the first stage.
     Note that the input tensor is updated in place."""

@@ -48,20 +75,24 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
     if is_last_stage or is_first_stage:
+        assert tensor is not None
+        assert tensor.is_cuda
+        is_contiguous = tensor.is_contiguous()
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
-        if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            tensor_ = tensor.contiguous()
+        if is_contiguous:
+            tensor_ = tensor
         else:
-            tensor_ = torch.empty(size,
-                                  dtype=dtype,
-                                  device=torch.cuda.current_device())
+            if is_last_stage:
+                tensor_ = tensor.contiguous()
+            else:
+                tensor_ = torch.empty(size,
+                                      dtype=dtype,
+                                      device=torch.cuda.current_device())
         # Broadcast from last stage into the first stage.
         torch.distributed.broadcast(tensor_, src, group)
         # Update the first stage tensor
-        if is_first_stage:
+        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
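The pattern in both functions is the same: the last pipeline stage sends a contiguous CUDA tensor, the first stage allocates an empty buffer of the agreed size and dtype, and both call broadcast over the embedding group (which contains only those two ranks). Below is a minimal standalone sketch of that allocate-then-broadcast pattern, not part of the commit, using plain torch.distributed instead of Megatron's mpu helpers; the two-process gloo setup and the function name broadcast_last_to_first are illustrative assumptions.

# Standalone sketch of the allocate-then-broadcast pattern, without mpu.
# Run with exactly two processes, e.g.:
#   torchrun --nproc_per_node=2 broadcast_sketch.py
# Rank 1 plays the "last stage", rank 0 the "first stage".
import torch
import torch.distributed as dist


def broadcast_last_to_first(size, dtype, tensor=None):
    rank = dist.get_rank()
    last_rank = dist.get_world_size() - 1
    if rank == last_rank:
        # Sender must provide a contiguous tensor.
        assert tensor is not None and tensor.is_contiguous()
    elif rank == 0:
        # Receiver allocates a buffer of the agreed size and dtype.
        tensor = torch.empty(size, dtype=dtype)
    else:
        # In Megatron, ranks outside the embedding group simply return None.
        return None
    # Megatron runs this over the embedding group; here we use the
    # default group, which is why exactly two processes are assumed.
    dist.broadcast(tensor, src=last_rank)
    return tensor


if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    payload = None
    if dist.get_rank() == dist.get_world_size() - 1:
        payload = torch.arange(4, dtype=torch.float32)
    out = broadcast_last_to_first((4,), torch.float32, payload)
    print(dist.get_rank(), out)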
megatron/inference/generation.py  (+104 -33)

@@ -19,19 +19,44 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, get_tokenizer
-from megatron import mpu
+from megatron import get_args, get_tokenizer, mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
-    broadcast_from_last_pipeline_stage)
+    broadcast_from_last_pipeline_stage,
+    broadcast_from_last_to_first_pipeline_stage)
 from .forward_step import forward_step
 from .sampling import sample


-def generate_tokens(model, tokens, lengths, return_all_probs=False,
-                    temperature=1.0):
-    """Main token generation function."""
+def generate_tokens_probs_and_return_on_first_stage(
+        model, tokens, lengths,
+        return_output_log_probs=False,
+        return_all_log_probs=False,
+        temperature=1.0):
+    """Main token generation function.
+    Arguments:
+        model: XXX
+        tokens: prompt tokens extended to be of size [b, max-sequence-length]
+        lengths: original prompt length, size: [b]
+        return_output_log_probs: flag to calculate the log probability of
+            the generated tokens. Note that the log probability is the one
+            after the logits are modified for sampling.
+        return_all_log_probs: flag to calculate the log probabilities across
+            all the tokens (vocab size). Note that the log probability is
+            the one after the logits are modified for sampling.
+        temperature: sampling temperature.
+        Note: outside of the model, the other parameters only need to be
+            available on rank 0.
+    Outputs: Note that the sizes are adjusted to a lower value than
+        max-sequence-length if generation is terminated early.
+        tokens: prompt and generated tokens. size: [b, :]
+        generated_sequence_lengths: total length (including prompt) of
+            the generated sequence. size: [b]
+        output_log_probs: log probability of the selected tokens. size: [b, s]
+        all_log_probs: log probability of all the tokens.
+            size: [b, s, vocab-size]
+    """
     args = get_args()
     tokenizer = get_tokenizer()

@@ -52,18 +77,35 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
     # ===================
     # Pre-allocate memory
     # ===================

-    # Log probability of the sequence (prompt + generated tokens)
-    output_log_probs = torch.empty(batch_size, max_sequence_length - 1,
-                                   dtype=torch.float32,
-                                   device=torch.cuda.current_device())
+    # Log probability of the sequence (prompt + generated tokens).
+    output_log_probs = None
+    output_log_probs_size = (batch_size, max_sequence_length - 1)
+    # Log probability of all tokens for the sequence.
+    all_log_probs = None
+    all_log_probs_size = (batch_size, max_sequence_length - 1,
+                          args.padded_vocab_size)
     # Lengths of generated sequences, including prompts.
-    generated_sequence_lengths = torch.ones(
-        batch_size, dtype=torch.int64,
-        device=torch.cuda.current_device()) * max_sequence_length
+    generated_sequence_lengths = None
+    if mpu.is_pipeline_last_stage():
+        if return_output_log_probs:
+            output_log_probs = torch.empty(output_log_probs_size,
+                                           dtype=torch.float32,
+                                           device=torch.cuda.current_device())
+        if return_all_log_probs:
+            all_log_probs = torch.empty(all_log_probs_size,
+                                        dtype=torch.float32,
+                                        device=torch.cuda.current_device())
+        generated_sequence_lengths = torch.ones(
+            batch_size, dtype=torch.int64,
+            device=torch.cuda.current_device()) * max_sequence_length
     # Whether we have reached a termination id.
     is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                      device=torch.cuda.current_device())

     # =============
     # Run inference
     # =============

     attention_mask, position_ids = _build_attention_mask_and_position_ids(
         tokens)

@@ -114,15 +156,25 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
-            log_probs = F.log_softmax(logits, dim=2)
-            # Pick the tokens that we need to get the log probabilities for.
-            # Note that the next input token is the token which we selected
-            # in the current logits, so shift by 1.
-            indices = torch.unsqueeze(
-                tokens[:, (prev_context_length + 1):(context_length + 1)],
-                2)
-            output_log_probs[:, prev_context_length:context_length] = \
-                torch.gather(log_probs, 2, indices).squeeze(2)
+            if return_output_log_probs or return_all_log_probs:
+                log_probs = F.log_softmax(logits, dim=2)
+                if return_all_log_probs:
+                    all_log_probs[:,
+                                  prev_context_length:context_length,
+                                  :] = log_probs
+                if return_output_log_probs:
+                    # Pick the tokens that we need to get the log
+                    # probabilities for. Note that the next input token is
+                    # the token which we selected in the current logits,
+                    # so shift by 1.
+                    indices = torch.unsqueeze(
+                        tokens[:, (prev_context_length + 1):(context_length + 1)],
+                        2)
+                    output_log_probs[:, prev_context_length:context_length] = \
+                        torch.gather(log_probs, 2, indices).squeeze(2)

         # Update the tokens on the first stage so the next input to
         # the network is correct.

@@ -147,17 +199,36 @@ def generate_tokens(model, tokens, lengths, return_all_probs=False,
         if done:
             break

-    if mpu.is_pipeline_last_stage():
-        if return_all_probs:
-            full_logits = None
-            return tokens, generated_sequence_lengths, output_log_probs, \
-                full_logits, context_length + 1
-        return tokens, generated_sequence_lengths, output_log_probs, \
-            None, context_length + 1
-
-    if mpu.is_pipeline_first_stage():
-        return tokens, None, None, None, context_length + 1
-
-    return None, None, None, None, context_length + 1
+    # =====================================================
+    # Update the lengths based on the max generated length.
+    # =====================================================
+
+    tokens = tokens[:, :(context_length + 1)]
+    if mpu.is_pipeline_last_stage():
+        if return_output_log_probs:
+            output_log_probs = output_log_probs[:, :context_length]
+        if return_all_log_probs:
+            all_log_probs = all_log_probs[:, :context_length, :]
+
+    # ======================================
+    # Broadcast to the first pipeline stage.
+    # ======================================
+
+    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
+        batch_size, torch.int64, generated_sequence_lengths)
+    if return_output_log_probs:
+        output_log_probs_size = (batch_size, context_length)
+        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
+            output_log_probs_size, torch.float32, output_log_probs)
+    if return_all_log_probs:
+        all_log_probs_size = (batch_size, context_length,
+                              args.padded_vocab_size)
+        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
+            all_log_probs_size, torch.float32, all_log_probs)
+
+    return tokens, generated_sequence_lengths, output_log_probs, \
+        all_log_probs


 def _build_attention_mask_and_position_ids(tokens):
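The shift-by-one gather above is the core of the per-token log-probability bookkeeping: the logits at position i score the token that ends up at position i + 1. Below is a small self-contained illustration of that gather, not part of the commit, with toy shapes chosen only for demonstration.

# Self-contained illustration of the shift-by-one gather used above to
# pull out the log probability of each selected token.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 11
# Logits for positions 0..seq_len-1; each position predicts the *next* token.
logits = torch.randn(batch_size, seq_len, vocab_size)
# Token ids for positions 0..seq_len (prompt plus generated tokens).
tokens = torch.randint(vocab_size, (batch_size, seq_len + 1))

log_probs = F.log_softmax(logits, dim=2)                 # [b, s, v]
# The token predicted at position i is tokens[:, i + 1], hence the shift.
indices = torch.unsqueeze(tokens[:, 1:], 2)              # [b, s, 1]
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)  # [b, s]

assert output_log_probs.shape == (batch_size, seq_len)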
megatron/inference/tokenization.py  (+33 -0)

@@ -23,6 +23,39 @@ from megatron import get_tokenizer
 from .communication import broadcast_int_list, broadcast_tensor


+def detokenize_generations(tokens_gpu_tensor, lengths_gpu_tensor,
+                           return_segments):
+    """Detokenize the generated tokens."""
+
+    tokenizer = get_tokenizer()
+    prompts_plus_generations = []
+    if return_segments:
+        prompts_plus_generations_segments = []
+
+    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
+    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
+    for sequence_tokens, length in zip(tokens, lengths):
+        sequence_tokens = sequence_tokens[:length]
+        prompts_plus_generations.append(
+            tokenizer.detokenize(sequence_tokens))
+        if return_segments:
+            words = []
+            for token in sequence_tokens:
+                word = tokenizer.tokenizer.decoder[token]
+                word = bytearray(
+                    [tokenizer.tokenizer.byte_decoder[c]
+                     for c in word]).decode('utf-8', errors='replace')
+                words.append(word)
+            prompts_plus_generations_segments.append(words)
+
+    if return_segments:
+        return tokens, prompts_plus_generations, \
+            prompts_plus_generations_segments
+
+    return tokens, prompts_plus_generations
+
+
 def tokenize_prompts(prompts=None, tokens_to_generate=None, rank=0):
     """Tokenize prompts and make them available on all ranks."""
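The added helper trims each sequence to its reported length before detokenizing, and optionally decodes per-token segments through the GPT-2 byte decoder. A toy sketch of the trimming loop follows, not part of the commit; the SimpleTokenizer below is a hypothetical stand-in, not Megatron's tokenizer.

# Toy sketch of the trim-then-detokenize loop with a stand-in tokenizer.
import torch


class SimpleTokenizer:
    """Hypothetical tokenizer that only maps ids to strings."""
    def __init__(self, vocab):
        self.vocab = vocab

    def detokenize(self, ids):
        return ' '.join(self.vocab[i] for i in ids)


tokenizer = SimpleTokenizer(['<pad>', 'hello', 'world', '!'])
tokens_gpu = torch.tensor([[1, 2, 3, 0, 0],
                           [1, 3, 0, 0, 0]])
lengths_gpu = torch.tensor([3, 2])

tokens = tokens_gpu.cpu().numpy().tolist()
lengths = lengths_gpu.cpu().numpy().tolist()
prompts_plus_generations = []
for sequence_tokens, length in zip(tokens, lengths):
    # Drop the padding beyond each sequence's reported length.
    sequence_tokens = sequence_tokens[:length]
    prompts_plus_generations.append(tokenizer.detokenize(sequence_tokens))

print(prompts_plus_generations)   # ['hello world !', 'hello !']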
megatron/text_generation_utils.py  (+11 -2)

@@ -153,8 +153,12 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if mpu.is_pipeline_last_stage():
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
+        print('last rank output size {} {} |\n'.format(
+            output_logits.size(0), output_logits.size(1)))
         torch.distributed.broadcast(output_logits, src, group)
         if all_probs:
+            print('last rank full size {} {} |\n'.format(
+                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
             torch.distributed.broadcast(full_logits, src, group)

@@ -164,13 +168,18 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
         output_logits = torch.empty(tokens.size(0), context_length - 1,
                                     dtype=torch.float32,
                                     device=torch.device("cuda"))
+        print('first rank output size {} {} |\n'.format(
+            output_logits.size(0), output_logits.size(1)))
         torch.distributed.broadcast(output_logits, src, group)
         if all_probs:
             args = get_args()
             src = mpu.get_pipeline_model_parallel_last_rank()
             group = mpu.get_embedding_group()
-            full_logits = torch.empty(tokens.size(0), context_length,
+            full_logits = torch.empty(tokens.size(0), context_length - 1,
                                       args.padded_vocab_size,
                                       dtype=torch.float32,
                                       device=torch.device("cuda"))
+            print('first rank full size {} {} |\n'.format(
+                full_logits.size(0), full_logits.size(1), full_logits.size(2)))
             torch.distributed.broadcast(full_logits, src, group)

     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits

@@ -204,7 +213,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
     output_logits = output_logits.cpu().numpy().tolist()
     if all_probs:
-        full_logits = full_logits.cpu().numpy().tolist()
+        full_logits = full_logits.cpu().numpy()  # .tolist()
     return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens