OpenDAS / Megatron-LM / Commits

Commit 8c119d80, authored Oct 10, 2021 by mshoeybi
Parent: a7539b0f

    tested and working

Showing 5 changed files with 155 additions and 53 deletions (+155, -53):

    megatron/arguments.py                 +13   -0
    megatron/inference/api.py             +41  -11
    megatron/inference/communication.py   +36   -0
    megatron/inference/forward_step.py    +47  -32
    megatron/inference/generation.py      +18  -10
megatron/arguments.py

@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_inference_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -273,6 +274,18 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
 
 
+def _add_inference_args(parser):
+    group = parser.add_argument_group(title='inference')
+
+    group.add_argument('--inference-batch-times-seqlen-threshold',
+                       type=int, default=512,
+                       help='During inference, if batch-size times '
+                            'sequence-length is smaller than this threshold '
+                            'then we will not use pipelining, otherwise we will.')
+
+    return parser
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
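For reference, a minimal standalone sketch of how the new --inference-batch-times-seqlen-threshold flag plugs into an argparse parser. The toy parser below is an illustration only; Megatron's real argument plumbing in parse_args() is omitted.

    import argparse

    def _add_inference_args(parser):
        # Mirrors the group added in this commit: one threshold that decides
        # whether inference uses pipelining (batch * seqlen >= threshold).
        group = parser.add_argument_group(title='inference')
        group.add_argument('--inference-batch-times-seqlen-threshold',
                           type=int, default=512,
                           help='During inference, if batch-size times '
                                'sequence-length is smaller than this threshold '
                                'then we will not use pipelining, otherwise we will.')
        return parser

    if __name__ == '__main__':
        parser = argparse.ArgumentParser(description='toy example')
        parser = _add_inference_args(parser)
        args = parser.parse_args(['--inference-batch-times-seqlen-threshold', '1024'])
        # argparse converts dashes to underscores in the attribute name.
        print(args.inference_batch_times_seqlen_threshold)  # -> 1024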
megatron/inference/api.py

@@ -26,14 +26,20 @@ from .tokenization import (
     detokenize_generations)
 
 
 def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
                               return_all_log_probs=False,
                               greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
                               temperature=1.0,
-                              add_BOS=False):
-    """TO DO ..."""
+                              add_BOS=False,
+                              use_eod_token_for_early_termination=True):
+    """Run inference and post-process outputs, i.e., detokenize,
+    move to cpu, and convert to list."""
 
     # Main inference.
     tokens, lengths, output_log_probs, all_log_probs = generate(
@@ -42,8 +48,12 @@ def generate_and_post_process(model,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
         greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
-        add_BOS=add_BOS)
+        add_BOS=add_BOS,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
 
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
@@ -62,24 +72,42 @@ def generate_and_post_process(model,
     return None
 
 
 def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
              return_all_log_probs=False,
              greedy_sampling=False,
              top_k_sampling=0,
              top_p_sampling=0.0,
              temperature=1.0,
-             add_BOS=False):
-    """TO DO ..."""
+             add_BOS=False,
+             use_eod_token_for_early_termination=True):
+    """Given prompts and input parameters, run inference and return:
+       tokens: prompts plus the generated tokens.
+       lengths: length of the prompt + generations. Note that we can
+           discard tokens in the tokens tensor that are after the
+           corresponding length.
+       output_log_probs: log probs of the tokens.
+       all_log_probs: full log probs for all of the tokens.
+    """
 
     # Make sure input params are available to all ranks.
-    values = [tokens_to_generate, return_output_log_probs,
-              return_all_log_probs, temperature, add_BOS]
-    values_float_tensor = broadcast_float_list(5, float_list=values)
+    values = [tokens_to_generate, return_output_log_probs,
+              return_all_log_probs, greedy_sampling, top_k_sampling,
+              top_p_sampling, temperature, add_BOS,
+              use_eod_token_for_early_termination]
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     return_all_log_probs = bool(values_float_tensor[2].item())
-    temperature = values_float_tensor[3].item()
-    add_BOS = bool(values_float_tensor[4].item())
+    greedy_sampling = bool(values_float_tensor[3].item())
+    top_k_sampling = int(values_float_tensor[4].item())
+    top_p_sampling = values_float_tensor[5].item()
+    temperature = values_float_tensor[6].item()
+    add_BOS = bool(values_float_tensor[7].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
 
     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcasted to all ranks.
@@ -95,4 +123,6 @@ def generate(model,
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
         return_all_log_probs=return_all_log_probs,
-        temperature=temperature)
+        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
+        temperature=temperature,
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
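generate() now funnels all nine scalar knobs through a single float tensor so that one broadcast makes them visible to every rank. Below is a standalone sketch of that pack/unpack pattern; torch.distributed.broadcast (commented out) stands in for Megatron's broadcast_float_list helper, and the concrete values are made up for illustration.

    import torch

    # Values as they would exist on rank 0 before the broadcast.
    tokens_to_generate = 64
    return_output_log_probs = True
    return_all_log_probs = False
    greedy_sampling = False
    top_k_sampling = 40
    top_p_sampling = 0.0
    temperature = 0.5
    add_BOS = False
    use_eod_token_for_early_termination = True

    # Pack everything into one float tensor; bools and ints survive as floats.
    values = [tokens_to_generate, return_output_log_probs, return_all_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
    values_float_tensor = torch.tensor(values, dtype=torch.float32)

    # In the real code this single tensor is broadcast from rank 0, e.g.:
    # torch.distributed.broadcast(values_float_tensor, src=0)

    # Every rank then unpacks with the appropriate casts.
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    return_all_log_probs = bool(values_float_tensor[2].item())
    greedy_sampling = bool(values_float_tensor[3].item())
    top_k_sampling = int(values_float_tensor[4].item())
    top_p_sampling = values_float_tensor[5].item()
    temperature = values_float_tensor[6].item()
    add_BOS = bool(values_float_tensor[7].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
    print(tokens_to_generate, top_k_sampling, temperature)  # 64 40 0.5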
megatron/inference/communication.py

@@ -21,6 +21,38 @@ import torch
 from megatron import mpu
 
 
+def recv_from_prev_pipeline_rank_(recv_buffer=None):
+    """Receive from previous pipeline stage and update the
+    input buffer inplace."""
+    if not mpu.is_pipeline_first_stage():
+        assert recv_buffer is not None
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv, recv_buffer,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
+def send_to_next_pipeline_rank(tensor=None):
+    """Send output to the next pipeline stage."""
+    if not mpu.is_pipeline_last_stage():
+        assert tensor is not None
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor,
+            mpu.get_pipeline_model_parallel_next_rank())
+        reqs = torch.distributed.batch_isend_irecv([send_next_op])
+        for req in reqs:
+            req.wait()
+        # To protect against race condition when using batch_isend_irecv().
+        torch.cuda.synchronize()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""
@@ -96,6 +128,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
             tensor[...] = tensor_
 
 
 def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
@@ -114,6 +147,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     return tensor
 
 
 def broadcast_list(size, dtype, list_values=None, rank=0):
     """Broadcast a list of values with a given type."""
@@ -125,12 +159,14 @@ def broadcast_list(size, dtype, list_values=None, rank=0):
     return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
 
 
 def broadcast_int_list(size, int_list=None, rank=0):
     """Broadcast a list of integer values."""
     return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
 
 
 def broadcast_float_list(size, float_list=None, rank=0):
     """Broadcast a list of float values."""
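The two new helpers wrap point-to-point pipeline communication in torch.distributed.batch_isend_irecv. Below is a minimal, self-contained sketch of that pattern outside Megatron: two CPU processes on a gloo backend, where rank 0 plays the "previous" stage and rank 1 the "next" stage. The process-group setup, tensor shape, and port are illustrative assumptions, not part of the commit; the CUDA synchronize the helpers perform is irrelevant on CPU and omitted here.

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run(rank, world_size):
        # Minimal process-group setup (Megatron uses its own mpu initialization).
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29501'
        dist.init_process_group('gloo', rank=rank, world_size=world_size)

        buf = torch.zeros(4)
        if rank == 0:
            # "Previous" stage: send its activations to the next rank.
            buf = torch.arange(4, dtype=torch.float32)
            op = dist.P2POp(dist.isend, buf, peer=1)
        else:
            # "Next" stage: receive into a pre-allocated buffer, in place.
            op = dist.P2POp(dist.irecv, buf, peer=0)

        # Same pattern as the commit's helpers: batch the op, then wait on requests.
        for req in dist.batch_isend_irecv([op]):
            req.wait()

        if rank == 1:
            print('received:', buf)  # tensor([0., 1., 2., 3.])
        dist.destroy_process_group()


    if __name__ == '__main__':
        mp.spawn(run, args=(2,), nprocs=2)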
megatron/inference/forward_step.py

@@ -22,14 +22,20 @@ import torch
 from megatron import (
     get_args,
     mpu)
+from .communication import (
+    send_to_next_pipeline_rank,
+    recv_from_prev_pipeline_rank_)
 
 
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
     to efficiently calculate and store the context during inference."""
 
     def __init__(self, max_batch_size, max_sequence_len):
         """Note that offsets are set to zero and we always set the
         flag to allocate memory. After the first call, make sure to
         set this flag to False."""
         self.max_sequence_len = max_sequence_len
         self.max_batch_size = max_batch_size
         self.sequence_len_offset = 0
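InferenceParams carries a sequence-length offset so the model can append newly computed keys and values to a pre-allocated cache instead of recomputing the full context each step. A rough, generic sketch of that bookkeeping follows; the cache layout and the update_kv_cache helper are illustrative assumptions, not Megatron's actual attention code.

    import torch


    class ToyInferenceParams:
        """Illustrative stand-in for the commit's InferenceParams."""
        def __init__(self, max_batch_size, max_sequence_len):
            self.max_batch_size = max_batch_size
            self.max_sequence_len = max_sequence_len
            self.sequence_len_offset = 0


    def update_kv_cache(cache_k, cache_v, new_k, new_v, params):
        """Write this step's keys/values at the current offset, then advance it."""
        start = params.sequence_len_offset
        end = start + new_k.size(1)
        assert end <= params.max_sequence_len, 'cache overflow'
        cache_k[:, start:end] = new_k
        cache_v[:, start:end] = new_v
        params.sequence_len_offset = end
        # Attention can now read cache_k[:, :end] / cache_v[:, :end].
        return cache_k[:, :end], cache_v[:, :end]


    if __name__ == '__main__':
        b, max_len, d = 2, 8, 4
        params = ToyInferenceParams(max_batch_size=b, max_sequence_len=max_len)
        cache_k = torch.zeros(b, max_len, d)
        cache_v = torch.zeros(b, max_len, d)
        # Prompt of length 3, then one generated token at a time.
        for step_len in (3, 1, 1):
            k = torch.randn(b, step_len, d)
            v = torch.randn(b, step_len, d)
            keys, values = update_kv_cache(cache_k, cache_v, k, v, params)
        print(params.sequence_len_offset, keys.shape)  # 5 torch.Size([2, 5, 4])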
@@ -39,38 +45,50 @@ class InferenceParams:
 
 class ForwardStep:
     """Forward step function with all the communications.
     We use a class here to hide the inference parameters
     from the outside caller."""
 
     def __init__(self, model, max_batch_size, max_sequence_len):
         """Set values so we don't need to do it multiple times."""
         # Make sure model is in eval mode.
-        if isinstance(model, Iterable):
-            for this_model in model:
-                this_model.eval()
-        else:
-            model.eval()
+        assert not isinstance(model, Iterable), \
+            'interleaving schedule is not supported for inference'
+        model.eval()
         self.model = model
-        self.constant = 512
         # Initialize inference parameters.
         self.inference_params = InferenceParams(max_batch_size,
                                                 max_sequence_len)
+        # Pipelining arguments.
+        args = get_args()
+        self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size
+        # Threshold of pipelining.
+        self.pipelining_batch_x_seqlen = \
+            args.inference_batch_times_seqlen_threshold
 
     def __call__(self, tokens, position_ids, attention_mask):
-        if tokens.size(0) * tokens.size(1) >= self.constant:
-            micro_batch_size = max(1, self.constant // tokens.size(1))
-            return _with_pipelining_forward_step(self.model,
-                                                 tokens,
-                                                 position_ids,
-                                                 attention_mask,
-                                                 self.inference_params,
-                                                 micro_batch_size)
-        else:
-            return _no_pipelining_forward_step(self.model,
-                                               tokens,
-                                               position_ids,
-                                               attention_mask,
-                                               self.inference_params)
+        """Invocation of the forward methods. Note that self.inference_params
+        is being modified by the forward step."""
+        # Pipelining case.
+        if self.pipeline_size_larger_than_one:
+            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
+            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
+                micro_batch_size = \
+                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
+                return _with_pipelining_forward_step(self.model,
+                                                     tokens,
+                                                     position_ids,
+                                                     attention_mask,
+                                                     self.inference_params,
+                                                     micro_batch_size)
+
+        return _no_pipelining_forward_step(self.model,
+                                           tokens,
+                                           position_ids,
+                                           attention_mask,
+                                           self.inference_params)
 
 
 def _get_recv_buffer_dtype(args):
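The rewritten __call__ only pipelines when batch_size * seq_len reaches the configured threshold, and then derives the micro-batch size from that same threshold. A small standalone sketch of that decision and the resulting batch split; the function name and the driver at the bottom are illustrative, not the commit's helpers.

    import torch


    def split_for_pipelining(tokens, pipeline_size, threshold):
        """Return a list of micro-batches if pipelining should be used,
        otherwise the full batch as a single element."""
        batch_size, seq_len = tokens.size(0), tokens.size(1)
        if pipeline_size > 1 and batch_size * seq_len >= threshold:
            # Same rule as ForwardStep: at least 1, at most threshold // seq_len.
            micro_batch_size = max(1, threshold // seq_len)
            return list(torch.split(tokens, micro_batch_size, dim=0))
        return [tokens]


    if __name__ == '__main__':
        tokens = torch.zeros(8, 128, dtype=torch.long)   # batch=8, seqlen=128
        micro_batches = split_for_pipelining(tokens, pipeline_size=2, threshold=512)
        # 8 * 128 = 1024 >= 512, micro_batch_size = 512 // 128 = 4 -> two chunks.
        print([mb.size(0) for mb in micro_batches])      # [4, 4]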
@@ -103,9 +121,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
         recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
 
     # Receive from previous stage.
-    if not mpu.is_pipeline_first_stage():
-        torch.distributed.recv(recv_buffer,
-                               src=mpu.get_pipeline_model_parallel_prev_rank())
+    recv_from_prev_pipeline_rank_(recv_buffer)
 
     # Forward pass through the model.
     model.set_input_tensor(recv_buffer)
@@ -113,9 +129,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
                           inference_params=inference_params)
 
     # Send output to the next stage.
-    if not mpu.is_pipeline_last_stage():
-        torch.distributed.send(output_tensor,
-                               mpu.get_pipeline_model_parallel_next_rank())
+    send_to_next_pipeline_rank(output_tensor)
 
     # Make sure we do not allocate context memory anymore.
     if inference_params.allocate_key_value_memory:
@@ -128,7 +142,7 @@ def _forward_step_helper(model, tokens, position_ids, attention_mask,
 
 def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                 inference_params, recv_buffer=None):
     """If recv_buffer is none, we will allocate one on the fly."""
 
     # Run a simple forward pass.
     output_tensor = _forward_step_helper(model, tokens, position_ids,
                                          attention_mask, inference_params,
@@ -143,9 +157,10 @@ def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
     return logits
 
 
 def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                   inference_params, micro_batch_size):
     """No interleaving is supported."""
     sequence_length = tokens.size(1)
     batch_size = tokens.size(0)
megatron/inference/generation.py

@@ -32,10 +32,12 @@ def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
         return_all_log_probs=False,
-        temperature=1.0):
+        greedy=False, top_k=0, top_p=0.0,
+        temperature=1.0,
+        use_eod_token_for_early_termination=True):
     """Main token generation function.
     Arguments:
-        model: XXX
+        model: no interleaving is supported.
         tokens: prompt tokens extended to be of size [b, max-sequence-length]
         lengths: original prompt length, size: [b]
         return_output_log_probs: flag to calculate the log probability of
@@ -44,7 +46,14 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_all_log_probs: flag to calculate the log probability across
             all the tokens (vocab size). Note that the log probability is the
            one after logits are modified for sampling.
+        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
+            Note that these three parameters are exclusive, meaning that:
+                if greedy = True then we should have top-k = top-p = 0.
+                if top-k > 0 then we expect greedy = False and top-p = 0.
+                if top-p > 0 then we check for greedy = False and top-k = 0.
+        temperature: sampling temperature.
+        use_eod_token_for_early_termination: if True, terminate early once
+            all the sequences have reached this token.
     Note: Outside of model, other parameters only need to be available on
         rank 0.
     Outputs: Note that the size is adjusted to a lower value than
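The docstring above constrains greedy, top-k, and top-p to be mutually exclusive. A small hypothetical validator that enforces those rules before sampling (not part of the commit):

    def check_sampling_params(greedy=False, top_k=0, top_p=0.0):
        """Enforce the exclusivity rules described in the docstring above."""
        if greedy:
            assert top_k == 0 and top_p == 0.0, \
                'greedy sampling expects top-k = top-p = 0'
        if top_k > 0:
            assert not greedy and top_p == 0.0, \
                'top-k sampling expects greedy = False and top-p = 0'
        if top_p > 0.0:
            assert not greedy and top_k == 0, \
                'top-p sampling expects greedy = False and top-k = 0'


    check_sampling_params(greedy=True)              # ok
    check_sampling_params(top_k=40)                 # ok
    check_sampling_params(top_p=0.9)                # ok
    # check_sampling_params(greedy=True, top_k=40)  # would raise AssertionError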
@@ -108,10 +117,9 @@ def generate_tokens_probs_and_return_on_first_stage(
     # Run inference
     # =============
 
-    attention_mask, position_ids = _build_attention_mask_and_position_ids(
-        tokens)
-
     with torch.no_grad():
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(
+            tokens)
         prev_context_length = 0
         for context_length in range(min_prompt_length, max_sequence_length):
@@ -132,9 +140,9 @@ def generate_tokens_probs_and_return_on_first_stage(
             last_token_logits = logits[:, -1, :]
             new_sample, updated_last_token_logits = sample(
                 last_token_logits,
-                greedy=args.greedy,
-                top_k=args.top_k,
-                top_p=args.top_p,
+                greedy=greedy,
+                top_k=top_k,
+                top_p=top_p,
                 temperature=temperature,
                 vocab_size=tokenizer.vocab_size)
 
             # Now that we have the sample and updated logits,
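For context, here is a rough standalone sketch of what a sample() of this shape typically does with greedy, top-k, top-p, and temperature. It is a generic illustration, not Megatron's actual sample implementation; the returned modified logits mirror how the commit reuses updated_last_token_logits for log probabilities.

    import torch


    def sample(last_token_logits, greedy=False, top_k=0, top_p=0.0,
               temperature=1.0, vocab_size=None):
        """Return (sampled token ids, logits after sampling modifications)."""
        logits = last_token_logits.clone().float()
        if vocab_size is not None:
            # Mask out padded vocab entries, if any.
            logits[:, vocab_size:] = -float('inf')

        if greedy:
            return torch.argmax(logits, dim=-1), logits

        logits = logits / temperature
        if top_k > 0:
            # Keep only the top-k logits per row.
            kth = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits = logits.masked_fill(logits < kth, -float('inf'))
        elif top_p > 0.0:
            # Nucleus sampling: drop the tail whose cumulative prob exceeds top_p.
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cumprobs > top_p
            remove[:, 1:] = remove[:, :-1].clone()   # always keep the top token
            remove[:, 0] = False
            logits = logits.masked_fill(
                remove.scatter(1, sorted_idx, remove), -float('inf'))

        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1), logits


    if __name__ == '__main__':
        logits = torch.randn(2, 10)
        ids, updated = sample(logits, top_k=3, temperature=0.8, vocab_size=10)
        print(ids.shape, updated.shape)   # torch.Size([2]) torch.Size([2, 10])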
@@ -189,8 +197,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             done = torch.all(is_generation_done)
             done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                       tensor=done)
-            #if done:
-            #    break
+            if use_eod_token_for_early_termination and done:
+                break
 
     # ===================================================
     # Update the length based on max generated length.
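With the flag set, the generation loop now stops as soon as every sequence has emitted the end-of-document token. A minimal single-process sketch of that bookkeeping, with a made-up eod id and pre-filled "generated" tokens; the broadcast of done across pipeline stages is omitted.

    import torch

    eod = 0                                    # hypothetical end-of-document token id
    tokens = torch.tensor([[5, 7, 0, 0],       # finished at step 2
                           [3, 0, 9, 2]])      # finished at step 1
    use_eod_token_for_early_termination = True

    batch_size = tokens.size(0)
    is_generation_done = torch.zeros(batch_size, dtype=torch.bool)

    for step in range(tokens.size(1)):
        new_sample = tokens[:, step]
        # A sequence is done once it has produced the EOD token at least once.
        is_generation_done |= new_sample == eod
        done = torch.all(is_generation_done)
        if use_eod_token_for_early_termination and done:
            break

    print(step, is_generation_done)  # 2 tensor([True, True])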