OpenDAS / Megatron-LM / Commits

Commit 67aa8619, authored Nov 01, 2021 by Jared Casper

    Merge branch 'main' into checkpoint_util

    parents: 03d09af0, f5345dfa

Changes: 46

Showing 20 changed files with 1532 additions and 485 deletions (+1532 -485)
megatron/mpu/layers.py                         +51   -5
megatron/optimizer/clip_grads.py               +2    -1
megatron/optimizer/optimizer.py                +2    -2
megatron/p2p_communication.py                  +45   -38
megatron/schedules.py                          +181  -25
megatron/text_generation/__init__.py           +3    -14
megatron/text_generation/api.py                +120  -0
megatron/text_generation/communication.py      +198  -0
megatron/text_generation/forward_step.py       +208  -0
megatron/text_generation/generation.py         +278  -0
megatron/text_generation/sampling.py           +106  -0
megatron/text_generation/tokenization.py       +131  -0
megatron/text_generation_server.py             +92   -28
megatron/text_generation_utils.py              +0    -333
megatron/training.py                           +56   -24
pretrain_bert.py                               +4    -3
pretrain_gpt.py                                +4    -3
pretrain_ict.py                                +2    -0
pretrain_t5.py                                 +47   -9
pretrain_vit.py                                +2    -0
megatron/mpu/layers.py

@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
 from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
@@ -200,6 +201,37 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


+class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+    """
+    Column-parallel linear layer execution with asynchronous all-reduce
+    execution in backprop.
+    """
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(input, weight)
+        ctx.use_bias = bias is not None
+        output = torch.matmul(input, weight.t())
+        if bias is not None:
+            output = output + bias
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        grad_input = grad_output.matmul(weight)
+        # Asynchronous all-reduce
+        handle = torch.distributed.all_reduce(
+            grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+        # Delay the start of weight gradient computation shortly (3us) to have
+        # all-reduce scheduled first and have GPU resources allocated
+        _ = torch.empty(1, device=grad_output.device) + 1
+        grad_weight = grad_output.t().matmul(input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        handle.wait()
+        return grad_input, grad_weight, grad_bias
+
+
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
@@ -276,16 +308,30 @@ class ColumnParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.async_tensor_model_parallel_allreduce = (
+            not args.no_async_tensor_model_parallel_allreduce and
+            world_size > 1)

     def forward(self, input_):
-        # Set up backprop all-reduce.
-        input_parallel = copy_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        if self.async_tensor_model_parallel_allreduce:
+            input_shape = input_.shape
+            input_ = input_.view(input_shape[0] * input_shape[1],
+                                 input_shape[2])
+            # Matrix multiply with asynchronous all-reduce execution
+            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
+                input_, self.weight, bias)
+            output_parallel = output_parallel.view(
+                input_shape[0], input_shape[1], output_parallel.shape[1])
+        else:
+            # Set up backprop all-reduce.
+            input_parallel = copy_to_tensor_model_parallel_region(input_)
+            # Matrix multiply.
+            output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
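The pattern introduced above (launch the all-reduce of the input gradient without blocking, compute the weight gradient while it is in flight, then wait) can be exercised in isolation. Below is a minimal sketch, not the commit's code: the class name and the single-process gloo group are illustrative, and the bias handling and tensor-parallel group of the real class are omitted.

import os
import torch
import torch.distributed as dist


class LinearWithAsyncGradAllreduce(torch.autograd.Function):
    """Toy version of the pattern: overlap the dgrad all-reduce with the
    wgrad matmul in backward, then wait before returning."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return torch.matmul(input, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        handle = dist.all_reduce(grad_input, async_op=True)  # non-blocking
        grad_weight = grad_output.t().matmul(input)          # overlapped work
        handle.wait()
        return grad_input, grad_weight


if __name__ == '__main__':
    # Single-process gloo group only to make the sketch runnable.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    LinearWithAsyncGradAllreduce.apply(x, w).sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([16, 8])
    dist.destroy_process_group()

Numerically this is the same as the synchronous path; the only difference is when the all-reduce of grad_input completes relative to the weight-gradient matmul.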
megatron/optimizer/clip_grads.py

@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)
         is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        grad = param.grad.detach()
+        if grad_not_none:
+            grad = param.grad.detach()
         if grad_not_none:
             # Make sure the grads are in fp32
             assert param.grad.type() == 'torch.cuda.FloatTensor'
megatron/optimizer/optimizer.py

@@ -179,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             a `main_grad` field. If this is set, we are assuming
             that the model parameters are store in the `main_grad`
             field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a contihuous buffer
+            for the DDP cases where there is a continuous buffer
             holding the gradients. For example for bfloat16, we want
             to do gradient accumulation and all-reduces in float32
             and as a result we store those gradients in the main_grad.
@@ -312,7 +312,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad:
+                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                     main_param.grad = model_param.main_grad.float()
                 else:
                     if model_param.grad is not None:
megatron/p2p_communication.py

@@ -22,8 +22,8 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False, tensor_shape=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
+                 tensor_shape,
+                 use_ring_exchange=False,
                  dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    previous rank.
         recv_next: boolean for whether tensor should be received from
                    next rank.
+        tensor_shape: shape of tensor to receive (this method assumes that all
+                      tensors sent and received in a single function call are
+                      the same shape).
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
-        tensor_shape: optional, use when the input sequence contains fewer
-                      tokens than the default sequence length
-        override_scatter_gather_tensors_in_pipeline: optional, this is used
-                                                     when tensor_shape is
-                                                     provided to override
-                                                     scatter gather tensors
-        dtype_: optional, this is used when tensor_shape is provided and what
-                is the type of tensor_shape
+        dtype_: optional, this is used when the tensor that needs to be
+                communicated is different from args.params_dtype.
     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -56,12 +53,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
-            mpu.get_tensor_model_parallel_world_size()
+
+    # Some legacy inference code doesn't set the tensor shape, so set it here
+    # to the normal values for gpt/bert. This could be removed if inference
+    # code is changed to provide tensor_shape.
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
+    override_scatter_gather_tensors_in_pipeline = False
+    if args.scatter_gather_tensors_in_pipeline:
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
+        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
+            tensor_chunk_shape = tensor_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
+        else:
+            tensor_chunk_shape = tensor_shape
+            override_scatter_gather_tensors_in_pipeline = True
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
@@ -143,9 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
-                 dtype_=None,
-                 timers=None):
+def recv_forward(tensor_shape=None, dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
@@ -159,15 +164,13 @@ def recv_forward(tensor_shape=None,
             recv_prev=True,
             recv_next=False,
             tensor_shape=tensor_shape,
-            override_scatter_gather_tensors_in_pipeline=\
-                override_scatter_gather_tensors_in_pipeline,
             dtype_=dtype_)
     if timers is not None:
         timers('forward-recv').stop()
     return input_tensor


-def recv_backward(timers=None):
+def recv_backward(tensor_shape=None, timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -178,15 +181,14 @@ def recv_backward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True)
+            recv_next=True,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('backward-recv').stop()
     return output_tensor_grad


-def send_forward(output_tensor, timers=None,
-                 override_scatter_gather_tensors_in_pipeline=False,
-                 dtype_=None):
+def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
@@ -197,14 +199,13 @@ def send_forward(output_tensor, timers=None,
             tensor_send_prev=None,
             recv_prev=False,
             recv_next=False,
-            override_scatter_gather_tensors_in_pipeline=\
-                override_scatter_gather_tensors_in_pipeline,
+            tensor_shape=tensor_shape,
             dtype_=dtype_)
     if timers is not None:
         timers('forward-send').stop()


-def send_backward(input_tensor_grad, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -213,12 +214,13 @@ def send_backward(input_tensor_grad, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -229,13 +231,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True)
+            recv_next=True,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('forward-send-backward-recv').stop()
     return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -246,13 +249,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('backward-send-forward-recv').stop()
     return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
     """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
@@ -260,13 +264,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=recv_prev,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
     """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
@@ -274,7 +279,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=recv_next)
+            recv_next=recv_next,
+            tensor_shape=tensor_shape)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad
@@ -282,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
-        recv_next, timers=None):
+        recv_next, tensor_shape=None, timers=None):
     """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
@@ -290,7 +296,8 @@ def send_forward_backward_recv_forward_backward(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next)
+        recv_next=recv_next,
+        tensor_shape=tensor_shape)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
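The core change here is that every send/recv helper now threads an explicit tensor_shape into _communicate, which flattens the tensor and splits it across tensor-parallel ranks only when the element count divides evenly. The following self-contained sketch mirrors just that chunking decision; the helper name and the example shapes are illustrative, not part of the commit.

from functools import reduce
import operator


def chunk_shape_for_scatter_gather(tensor_shape, tp_world_size,
                                   scatter_gather_enabled=True):
    """Mirror of the decision above: return the per-rank chunk description and
    whether scatter-gather had to be overridden (disabled) for this shape."""
    if not scatter_gather_enabled:
        return tensor_shape, False
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tp_world_size == 0:
        return numel // tp_world_size, False   # flattened chunk per TP rank
    return tensor_shape, True                  # fall back: send the full tensor


# (seq, batch, hidden) = (1024, 8, 1536) with 8-way tensor parallelism splits cleanly:
print(chunk_shape_for_scatter_gather((1024, 8, 1536), 8))  # (1572864, False)
# An odd shape that does not divide evenly forces the override path:
print(chunk_shape_for_scatter_gather((33, 3, 5), 4))        # ((33, 3, 5), True)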
megatron/schedules.py

@@ -25,6 +25,8 @@ from megatron import p2p_communication
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.model import ModelType
+

 def get_forward_backward_func():
     args = get_args()
@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     passed-in input_tensor is used.

     Returns output tensor."""
+    args = get_args()
     timers = get_timers()

     timers('forward-compute').start()
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
+
+    unwrap_output_tensor = False
+    if not isinstance(input_tensor, list):
+        input_tensor = [input_tensor]
+        unwrap_output_tensor = True
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
@@ -62,7 +71,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
             losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()

-    return output_tensor
+    # If T5 model (or other model with encoder and decoder)
+    # and in decoder stack, then send encoder_hidden_state
+    # downstream as well.
+    if mpu.is_pipeline_stage_after_split() and \
+            args.model_type == ModelType.encoder_and_decoder:
+        return [output_tensor, input_tensor[-1]]
+    if unwrap_output_tensor:
+        return output_tensor
+    return [output_tensor]


 def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
@@ -73,24 +90,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     Returns gradient of loss with respect to input tensor (None if first
     stage)."""

+    # NOTE: This code currently can handle at most one skip connection. It
+    # needs to be modified slightly to support arbitrary numbers of skip
+    # connections.
+    args = get_args()
     timers = get_timers()
     timers('backward-compute').start()

     # Retain the grad on the input_tensor.
-    if input_tensor is not None:
-        input_tensor.retain_grad()
+    unwrap_input_tensor_grad = False
+    if not isinstance(input_tensor, list):
+        input_tensor = [input_tensor]
+        unwrap_input_tensor_grad = True
+    for x in input_tensor:
+        if x is not None:
+            x.retain_grad()
+
+    if not isinstance(output_tensor, list):
+        output_tensor = [output_tensor]
+    if not isinstance(output_tensor_grad, list):
+        output_tensor_grad = [output_tensor_grad]

     # Backward pass.
-    if output_tensor_grad is None:
-        output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    if output_tensor_grad[0] is None:
+        output_tensor = optimizer.scale_loss(output_tensor[0])
+    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

     # Collect the grad of the input_tensor.
-    input_tensor_grad = None
+    input_tensor_grad = [None]
     if input_tensor is not None:
-        input_tensor_grad = input_tensor.grad
+        input_tensor_grad = []
+        for x in input_tensor:
+            if x is None:
+                input_tensor_grad.append(None)
+            else:
+                input_tensor_grad.append(x.grad)
+
+    # Handle single skip connection if it exists (encoder_hidden_state in
+    # model with encoder and decoder).
+    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
+            mpu.is_pipeline_stage_after_split() and \
+            args.model_type == ModelType.encoder_and_decoder:
+        if output_tensor_grad[1] is not None:
+            input_tensor_grad[-1].add_(output_tensor_grad[1])
+    if unwrap_input_tensor_grad:
+        input_tensor_grad = input_tensor_grad[0]

     timers('backward-compute').stop()
@@ -153,6 +199,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

+    args = get_args()
+    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
     num_microbatches = get_num_microbatches() * num_model_chunks
@@ -237,7 +286,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers=timers))
+        p2p_communication.recv_forward(tensor_shape, timers=timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
@@ -266,12 +315,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
                 p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
+                    tensor_shape=tensor_shape,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
             input_tensor = \
                 p2p_communication.send_forward_recv_forward(
-                    output_tensor, recv_prev=recv_prev, timers=timers)
+                    output_tensor, recv_prev=recv_prev,
+                    tensor_shape=tensor_shape,
+                    timers=timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
@@ -335,7 +387,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
             p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
-                timers=timers)
+                tensor_shape=tensor_shape, timers=timers)

         # Put input_tensor and output_tensor_grad in data structures in the
         # right location.
@@ -349,7 +401,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers=timers))
+                p2p_communication.recv_backward(tensor_shape, timers=timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -361,11 +413,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, ...):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
                 p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad, recv_next=recv_next, timers=timers))
+                    input_tensor_grad, recv_next=recv_next,
+                    tensor_shape=tensor_shape, timers=timers))

     return losses_reduced


+def get_tensor_shapes(rank, model_type):
+    # Determine right tensor sizes (based on position of rank with respect to split
+    # rank) and model size.
+    # Send two tensors if model is T5 and rank is in decoder stage:
+    #     first tensor is decoder (pre-transpose),
+    #     second tensor is encoder (post-transpose).
+    # If model is T5 and rank is at the boundary:
+    #     send one tensor (post-transpose from encoder).
+    # Otherwise, send one tensor (pre-transpose).
+    args = get_args()
+    tensor_shapes = []
+    if model_type == ModelType.encoder_and_decoder:
+        if mpu.is_pipeline_stage_before_split(rank):
+            # If next rank is after split, then need transpose for encoder_hidden_state.
+            if mpu.is_pipeline_stage_before_split(rank+1):
+                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+            else:
+                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+        else:
+            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+    else:
+        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+    return tensor_shapes
+
+
+def recv_forward(tensor_shapes, timers):
+    input_tensors = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            input_tensors.append(None)
+        else:
+            input_tensors.append(p2p_communication.recv_forward(tensor_shape,
+                                                                timers=timers))
+    return input_tensors
+
+
+def recv_backward(tensor_shapes, timers):
+    output_tensor_grads = []
+    for tensor_shape in tensor_shapes:
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+        else:
+            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
+                                                                       timers=timers))
+    return output_tensor_grads
+
+
+def send_forward(output_tensors, tensor_shapes, timers):
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
+
+
+def send_backward(input_tensor_grads, tensor_shapes, timers):
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            continue
+        p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
+
+
+def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
+    if not isinstance(output_tensors, list):
+        output_tensors = [output_tensors]
+    output_tensor_grads = []
+    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+        if tensor_shape is None:
+            output_tensor_grads.append(None)
+            continue
+        output_tensor_grad = p2p_communication.send_forward_recv_backward(
+            output_tensor, tensor_shape, timers=timers)
+        output_tensor_grads.append(output_tensor_grad)
+    return output_tensor_grads
+
+
+def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
+    if not isinstance(input_tensor_grads, list):
+        input_tensor_grads = [input_tensor_grads]
+    input_tensors = []
+    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+        if tensor_shape is None:
+            input_tensors.append(None)
+            continue
+        input_tensor = p2p_communication.send_backward_recv_forward(
+            input_tensor_grad, tensor_shape, timers=timers)
+        input_tensors.append(input_tensor)
+    return input_tensors
+
+
 def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                      model, optimizer, timers,
                                                      forward_only):
@@ -389,6 +537,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    model_type = unwrapped_model.model_type
+    rank = mpu.get_pipeline_model_parallel_rank()
+    recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
+    send_tensor_shapes = get_tensor_shapes(rank, model_type)
+
     # Input, output tensors only need to be saved when doing backward passes
     input_tensors = None
     output_tensors = None
@@ -399,10 +554,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = p2p_communication.recv_forward(timers=timers)
+        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
-        p2p_communication.send_forward(output_tensor, timers=timers)
+        send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
             input_tensors.append(input_tensor)
@@ -412,7 +567,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = p2p_communication.recv_forward(timers=timers)
+        input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
@@ -421,15 +576,16 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            p2p_communication.send_forward(output_tensor, timers=timers)
+            send_forward(output_tensor, send_tensor_shapes, timers=timers)

             if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers=timers)
+                input_tensor = recv_forward(recv_tensor_shapes, timers=timers)

         else:
             output_tensor_grad = \
-                p2p_communication.send_forward_recv_backward(output_tensor,
-                                                             timers=timers)
+                send_forward_recv_backward(output_tensor,
+                                           send_tensor_shapes,
+                                           timers=timers)

             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
@@ -446,11 +602,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
             if last_iteration:
                 input_tensor = None
-                p2p_communication.send_backward(input_tensor_grad, timers=timers)
+                send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
             else:
                 input_tensor = \
-                    p2p_communication.send_backward_recv_forward(
-                        input_tensor_grad, timers=timers)
+                    send_backward_recv_forward(
+                        input_tensor_grad, recv_tensor_shapes, timers=timers)

     # Run cooldown backward passes.
     if not forward_only:
@@ -458,12 +614,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, ...):
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)

-            output_tensor_grad = p2p_communication.recv_backward(timers=timers)
+            output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)

             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)

-            p2p_communication.send_backward(input_tensor_grad, timers=timers)
+            send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

     return losses_reduced
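The new get_tensor_shapes() helper encodes which tensors a pipeline rank exchanges per microbatch. A pure-Python mirror of that branching, with the mpu/args lookups replaced by explicit parameters so it can be run directly (all parameter names here are illustrative, not part of the commit):

def tensor_shapes_for_rank(is_before_split, next_is_before_split,
                           is_encoder_and_decoder,
                           seq_length, decoder_seq_length,
                           micro_batch_size, hidden_size):
    """Mirror of get_tensor_shapes(): decoder-side T5 ranks exchange two
    tensors (decoder activations plus the encoder hidden state), everyone
    else exchanges one."""
    shapes = []
    if is_encoder_and_decoder:
        if is_before_split:
            if next_is_before_split:
                shapes.append((seq_length, micro_batch_size, hidden_size))
            else:
                # Boundary rank: post-transpose encoder hidden state.
                shapes.append((micro_batch_size, seq_length, hidden_size))
        else:
            shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
            shapes.append((micro_batch_size, seq_length, hidden_size))
    else:
        shapes.append((seq_length, micro_batch_size, hidden_size))
    return shapes


# A decoder-side T5 rank exchanges two tensors per microbatch:
print(tensor_shapes_for_rank(False, False, True, 512, 128, 4, 1024))
# [(128, 4, 1024), (4, 512, 1024)]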
megatron/package_info.py → megatron/text_generation/__init__.py

@@ -13,18 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-MAJOR = 1
-MINOR = 1.5
-
-# Use the following formatting: (major, minor)
-VERSION = (MAJOR, MINOR)
-
-__version__ = '.'.join(map(str, VERSION))
-__package_name__ = 'megatron-lm'
-__contact_names__ = 'NVIDIA INC'
-__url__ = 'https://github.com/NVIDIA/Megatron-LM'
-__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
-__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
-__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
-__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
+
+from .api import (
+    generate,
+    generate_and_post_process)
megatron/text_generation/api.py   0 → 100644 (new file)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Inference API."""


import torch

from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
    generate_tokens_probs_and_return_on_first_stage,
    score_and_return_on_first_stage)
from .tokenization import (
    tokenize_prompts,
    detokenize_generations)


def generate_and_post_process(model,
                              prompts=None,
                              tokens_to_generate=0,
                              return_output_log_probs=False,
                              top_k_sampling=0,
                              top_p_sampling=0.0,
                              temperature=1.0,
                              add_BOS=False,
                              use_eod_token_for_early_termination=True):
    """Run inference and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, lengths, output_log_probs = generate(
        model,
        prompts=prompts,
        tokens_to_generate=tokens_to_generate,
        return_output_log_probs=return_output_log_probs,
        top_k_sampling=top_k_sampling,
        top_p_sampling=top_p_sampling,
        temperature=temperature,
        add_BOS=add_BOS,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)

    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)

        if return_output_log_probs:
            output_log_probs = output_log_probs.cpu().numpy().tolist()
            for i, (prob, seg) in enumerate(
                    zip(output_log_probs, prompts_plus_generations_segments)):
                output_log_probs[i] = prob[:len(seg)-1]

        return prompts_plus_generations, prompts_plus_generations_segments, \
            output_log_probs, tokens

    return None


def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,
             add_BOS=False,
             use_eod_token_for_early_termination=True):
    """Given prompts and input parameters, run inference and return:
       tokens: prompts plus the generated tokens.
       lengths: length of the prompt + generations. Note that we can
           discard tokens in the tokens tensor that are after the
           corresponding length.
       output_log_probs: log probs of the tokens.
    """

    # Make sure input params are available to all ranks.
    values = [tokens_to_generate,
              return_output_log_probs,
              top_k_sampling, top_p_sampling,
              temperature, add_BOS,
              use_eod_token_for_early_termination]
    values_float_tensor = broadcast_float_list(7, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_k_sampling = int(values_float_tensor[2].item())
    top_p_sampling = values_float_tensor[3].item()
    temperature = values_float_tensor[4].item()
    add_BOS = bool(values_float_tensor[5].item())
    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcast to all ranks.
    if torch.distributed.get_rank() == 0:
        assert prompts is not None

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

    if tokens_to_generate == 0:
        return score_and_return_on_first_stage(
            model, context_tokens_tensor, context_length_tensor)

    # Main inference function.
    # Note that the outputs are available on the first stage.
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
        top_k=top_k_sampling,
        top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
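generate() makes the scalar sampling knobs visible on every rank by packing them into a single float tensor before broadcast_float_list(). A self-contained sketch of just that pack/unpack round trip (the helper names pack_floats and unpack are illustrative, not part of the commit; only the decoding order matches the code above):

import torch


def pack_floats(values):
    """Pack heterogeneous scalars (ints, bools, floats) into one float tensor,
    the same trick generate() uses before broadcasting."""
    return torch.tensor(values, dtype=torch.float32)


def unpack(values_float_tensor):
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
    top_k = int(values_float_tensor[2].item())
    top_p = values_float_tensor[3].item()
    temperature = values_float_tensor[4].item()
    add_BOS = bool(values_float_tensor[5].item())
    use_eod_early_termination = bool(values_float_tensor[6].item())
    return (tokens_to_generate, return_output_log_probs, top_k, top_p,
            temperature, add_BOS, use_eod_early_termination)


packed = pack_floats([64, True, 0, 0.9, 1.0, False, True])
print(unpack(packed))  # (64, True, 0, 0.9, 1.0, False, True)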
megatron/text_generation/communication.py   0 → 100644 (new file)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Communications utilities."""


import torch

from megatron import mpu


# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from previous pipeline stage and update the
    input buffer inplace."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()


# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()


def _is_cuda(tensor):
    """Check if a tensor is not none and is cuda."""
    assert tensor is not None
    assert tensor.is_cuda


def _is_cuda_contiguous(tensor):
    """Check if a tensor is not none, is cuda, and is contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()


def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor


def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If first stage and last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only first and last stage pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_


def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """ Given size and type of a tensor on all ranks and the tensor value
        only on a specific rank, broadcast from that rank to all other ranks.
    """

    if torch.distributed.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(tensor, rank)

    return tensor


def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a list of values with a given type."""

    tensor = None
    if torch.distributed.get_rank() == rank:
        tensor = torch.tensor(list_values, dtype=dtype,
                              device=torch.cuda.current_device())

    return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)


def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of integer values."""

    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)


def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values."""

    return broadcast_list(size, torch.float32, list_values=float_list,
                          rank=rank)
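The broadcast helpers above follow one shape contract: the owning rank materializes the values, every other rank allocates an empty buffer of the agreed size and dtype, and torch.distributed.broadcast() fills it. A CPU-only, single-process sketch of that contract (gloo with world_size=1 is used here only to make it runnable; the real helpers assume a multi-rank CUDA job and the Megatron process groups):

import os
import torch
import torch.distributed as dist


def broadcast_list_cpu(size, dtype, list_values=None, src=0):
    """CPU mirror of broadcast_list()/broadcast_tensor()."""
    if dist.get_rank() == src:
        tensor = torch.tensor(list_values, dtype=dtype)
    else:
        tensor = torch.empty(size, dtype=dtype)
    dist.broadcast(tensor, src)
    return tensor


if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    print(broadcast_list_cpu(3, torch.int64, [7, 8, 9]))  # tensor([7, 8, 9])
    dist.destroy_process_group()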
megatron/text_generation/forward_step.py   0 → 100644 (new file)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron import (
    get_args,
    mpu)
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}


class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold of pipelining.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold

    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)


def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype


def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())


def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step. Update the allocate memory flag so
    only the first time the memory is allocated."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor


def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is none, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits


def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice among the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits
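The batch-splitting arithmetic in _with_pipelining_forward_step() is worth seeing on its own: divmod() gives the number of full micro batches, plus one more when there is a remainder, and each slice is clamped to the batch size. A standalone sketch (the generator name is illustrative):

def micro_batch_ranges(batch_size, micro_batch_size):
    """Yield the (start, end) slices _with_pipelining_forward_step iterates over."""
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    for i in range(num_micro_batches):
        start = i * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        yield start, end


print(list(micro_batch_ranges(10, 4)))  # [(0, 4), (4, 8), (8, 10)]

Note that the last, smaller chunk also explains why the preallocated recv_buffer is dropped for it: the buffer was sized for a full micro batch.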
megatron/text_generation/generation.py   0 → 100644 (new file)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generation utilities."""

import torch
import torch.nn.functional as F

from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
    copy_from_last_to_first_pipeline_stage,
    broadcast_from_last_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample


def score_and_return_on_first_stage(model, tokens, lengths):
    """Function for just scoring.
    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max_prompt_length]
        lengths: original prompt length, size: [b]
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Outputs:
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()

    batch_size = tokens.size(0)
    max_prompt_length = lengths.max().item()
    assert max_prompt_length == tokens.size(1)
    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)

    if mpu.is_pipeline_last_stage():
        output_log_probs = torch.empty(output_log_probs_size,
                                       dtype=torch.float32,
                                       device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)

        # logits will be meaningful only in the last pipeline stage.
        logits = forward_step(tokens, position_ids, attention_mask)

        if mpu.is_pipeline_last_stage():
            # Always the last stage should have an output.
            assert logits is not None
            log_probs = F.log_softmax(logits, dim=2)

            # Pick the tokens that we need to get the log
            # probabilities for. Note that next input token is
            # the token which we selected in the current logits,
            # so shift by 1.
            indices = torch.unsqueeze(tokens[:, 1:], 2)
            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================
    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
        output_log_probs_size, torch.float32, output_log_probs)

    return tokens, lengths, output_log_probs


def generate_tokens_probs_and_return_on_first_stage(
        model, tokens, lengths,
        return_output_log_probs=False,
        top_k=0, top_p=0.0, temperature=1.0,
        use_eod_token_for_early_termination=True):
    """Main token generation function.
    Arguments:
        model: no interleaving is supported.
        tokens: prompt tokens extended to be of size [b, max-sequence-length]
        lengths: original prompt length, size: [b]
        return_output_log_probs: flag to calculate the log probability of
            the generated tokens. Note that the log probability is the one
            from the original logit.
        top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive meaning that:
                if top-k > 0 then we expect top-p=0.
                if top-p > 0 then we check for top-k=0.
        temperature: sampling temperature.
        use_eod_token_for_early_termination: if True, do early termination if
            all the sequences have reached this token.
    Note: Outside of model, other parameters only need to be available on
          rank 0.
    Outputs: Note that the size is adjusted to a lower value than
             max-sequence-length if generation is terminated early.
        tokens: prompt and generated tokens. size: [b, :]
        generated_sequence_lengths: total length (including prompt) of
            the generated sequence. size: [b]
        output_log_probs: log probability of the selected tokens. size: [b, s]
    """

    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    min_prompt_length = lengths.min().item()
    max_sequence_length = tokens.size(1)
    max_sequence_length = min(max_sequence_length, args.max_position_embeddings)

    # forward step.
    forward_step = ForwardStep(model, batch_size, max_sequence_length)

    # Added termination_id to support the case that we want to terminate the
    # generation once that id is generated.
    if hasattr(args, 'eos_id'):
        termination_id = args.eos_id
    else:
        termination_id = tokenizer.eod

    # ===================
    # Pre-allocate memory
    # ===================

    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequence including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length

    # Whether we have reached a termination id.
    is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                     device=torch.cuda.current_device())

    # =============
    # Run inference
    # =============

    with torch.no_grad():
        attention_mask, position_ids = _build_attention_mask_and_position_ids(
            tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)

            if mpu.is_pipeline_last_stage():
                # Always the last stage should have an output.
                assert logits is not None

                # Sample.
                last_token_logits = logits[:, -1, :]
                new_sample = sample(last_token_logits,
                                    top_k=top_k,
                                    top_p=top_p,
                                    temperature=temperature,
                                    vocab_size=tokenizer.vocab_size)

                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens
                started = lengths <= context_length
                # Update the tokens.
                tokens[started, context_length] = new_sample[started]

                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    if return_output_log_probs:
                        # Pick the tokens that we need to get the log
                        # probabilities for. Note that next input token is
                        # the token which we selected in the current logits,
                        # so shift by 1.
                        indices = torch.unsqueeze(
                            tokens[
                                :,
                                (prev_context_length + 1):(context_length + 1)],
                            2)
                        output_log_probs[:,
                                         prev_context_length:context_length] = \
                            torch.gather(log_probs, 2, indices).squeeze(2)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Check if all the sequences have hit the termination_id.
            done = None
            if mpu.is_pipeline_last_stage():
                done_token = (new_sample == termination_id).byte() & \
                    started.byte()
                just_finished = (done_token & ~is_generation_done).bool()
                generated_sequence_lengths[just_finished.view(-1)] = \
                    context_length + 1
                is_generation_done = is_generation_done | done_token
                done = torch.all(is_generation_done)
            done = broadcast_from_last_pipeline_stage(1, torch.uint8,
                                                      tensor=done)
            if use_eod_token_for_early_termination and done:
                break

    # ===================================================
    # Update the length based on the max generated length.
    # ===================================================

    tokens = tokens[:, :(context_length + 1)]
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]

    # ======================================
    # Broadcast to the first pipeline stage.
    # ======================================

    generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
        batch_size, torch.int64, generated_sequence_lengths)
    if return_output_log_probs:
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)

    return tokens, generated_sequence_lengths, output_log_probs


def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""

    # Since we are not interested in loss-mask and reset attention/position
    # is also False, eod_token is not used so it is safe to set it to None.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)

    return attention_mask, position_ids
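The shift-by-one gather used in both scoring paths above is easy to verify in isolation: the logits at position t are the distribution over the token at t+1, so the log probability recorded at column t is looked up with the token id from column t+1. A small CPU sketch (tensor sizes are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, v = 2, 5, 11                       # batch, sequence, vocab (illustrative)
tokens = torch.randint(v, (b, s))
logits = torch.randn(b, s, v)

log_probs = F.log_softmax(logits, dim=2)
indices = torch.unsqueeze(tokens[:, 1:], 2)                       # next tokens, [b, s-1, 1]
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)  # [b, s-1]
print(output_log_probs.shape)            # torch.Size([2, 4])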
megatron/text_generation/sampling.py
0 → 100644
View file @
67aa8619
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch


def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non top-k values to -inf."""

    filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(filter_, float('-Inf'))


def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non top-p values to -inf."""

    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Filtering based on the cumulative sum.
    filter_ = cumulative_probs > top_p
    # This shift by 1 is weird and I cannot justify it. This existed
    # in the original implementation:
    #   https://github.com/ari-holtzman/degen/blob/master/gen.py
    # and I guess it is needed so keeping it for now.
    filter_[:, 1:] = filter_[:, :-1].clone()
    # Make sure we at least have one token to select from.
    filter_[..., 0] = 0

    # Fill in the filtered part.
    filter_ = filter_.scatter(1, sorted_indices, filter_)
    logits.masked_fill_(filter_, float('-Inf'))


def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
    """Sample and generate a token.
    Note: logits has the dimension [b, v] where b is the batch size
          and v is the vocabulary size.
    If vocab_size is provided, we will make sure the sample that is
    generated is in [0, vocab-size). This will avoid out of vocabulary
    generations due to padding.
    """

    # Check logits for consistency.
    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
    assert logits.type() == 'torch.cuda.FloatTensor', \
        'input logits should be floats.'

    # Greedy is just simple argmax.
    if top_k == 1:
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        samples = torch.argmax(logits, dim=-1)

    # Top-k or top-p sampling.
    else:
        # Clone so we do not modify the inputs.
        logits = logits.clone()
        # Apply temperature in place.
        if temperature != 1.0:
            logits.div_(temperature)

        if top_k > 1:
            assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
            assert top_k <= logits.size(1), 'top-k is larger than logit size.'
            if vocab_size:
                assert top_k < vocab_size, 'top-k is larger than vocab size.'
            modify_logits_for_top_k_filtering(logits, top_k)

        elif top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
            modify_logits_for_top_p_filtering(logits, top_p)

        # After filtering, we need to recalculate the distribution.
        probs = logits.softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1).view(-1)

    # If vocab size is provided, make sure the samples are
    # in the range [0, vocab-size).
    if vocab_size:
        samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

    return samples
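For context, a minimal sketch of how sample() might be exercised on its own; the batch size, vocabulary size, and random logits are illustrative, and a CUDA device is assumed because the function asserts torch.cuda.FloatTensor inputs:

# Illustrative only: call the sampling utility with synthetic logits.
import torch
from megatron.text_generation.sampling import sample

logits = torch.randn(4, 50257, device='cuda')            # [batch, vocab]
greedy = sample(logits, top_k=1)                          # plain argmax
top_k_out = sample(logits, top_k=40, temperature=0.9)     # top-k sampling
top_p_out = sample(logits, top_p=0.9, vocab_size=50257)   # nucleus sampling, clamped to vocab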
megatron/text_generation/tokenization.py 0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization utilities."""
import torch


from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor


def detokenize_generations(tokens_gpu_tensor,
                           lengths_gpu_tensor,
                           return_segments):
    """Detokenize the generated tokens."""

    tokenizer = get_tokenizer()

    prompts_plus_generations = []
    if return_segments:
        prompts_plus_generations_segments = []

    tokens = tokens_gpu_tensor.cpu().numpy().tolist()
    lengths = lengths_gpu_tensor.cpu().numpy().tolist()
    for sequence_tokens, length in zip(tokens, lengths):
        sequence_tokens = sequence_tokens[:length]
        prompts_plus_generations.append(
            tokenizer.detokenize(sequence_tokens))
        if return_segments:
            words = []
            for token in sequence_tokens:
                word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
                        'utf-8', errors='replace')
                words.append(word)
            prompts_plus_generations_segments.append(words)

    if return_segments:
        return tokens, prompts_plus_generations, \
            prompts_plus_generations_segments

    return tokens, prompts_plus_generations


def tokenize_prompts(prompts=None, tokens_to_generate=None,
                     add_BOS=None, rank=0):
    """Tokenize prompts and make them available on all ranks."""

    # On all ranks set to None so we can pass them to functions.
    sizes_list = None
    prompts_tokens_cuda_long_tensor = None
    prompts_length_cuda_long_tensor = None

    # On the specified rank, build the above.
    if torch.distributed.get_rank() == rank:
        assert prompts is not None
        assert tokens_to_generate is not None
        # Tensor of tokens padded and their unpadded length.
        prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
            _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
        # We need the sizes of these tensors for the broadcast.
        sizes_list = [prompts_tokens_cuda_long_tensor.size(0),  # Batch size
                      prompts_tokens_cuda_long_tensor.size(1)]  # Sequence length

    # First, broadcast the sizes.
    sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)

    # Now that we have the sizes, we can broadcast the tokens
    # and length tensors.
    sizes = sizes_tensor.tolist()
    prompts_tokens_cuda_long_tensor = broadcast_tensor(
        sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
    prompts_length_cuda_long_tensor = broadcast_tensor(
        sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
        rank=rank)

    return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor


def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
    """Given a set of prompts and number of tokens to generate:
        - tokenize prompts
        - set the sequence length to be the max of length of prompts
          plus the number of tokens we would like to generate
        - pad all the sequences to this length so we can convert them
          into a 2D tensor.
    """

    # Tokenize all the prompts.
    tokenizer = get_tokenizer()
    if add_BOS:
        prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
                          for prompt in prompts]
    else:
        prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]

    # Now we have a list of lists of tokens where each list has a different
    # size. We want to extend these lists to:
    #   - incorporate the tokens that need to be generated
    #   - make all the sequences equal length.
    # Get the prompts length.
    prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
    # Get the max prompts length.
    max_prompt_len = max(prompts_length)
    # Number of tokens in each sample of the batch.
    samples_length = max_prompt_len + tokens_to_generate
    # Now update the list of lists to be of the same size: samples_length.
    for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
        padding_size = samples_length - prompt_length
        prompt_tokens.extend([tokenizer.eod] * padding_size)

    # Now we are in a structured format, we can convert to tensors.
    prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
    prompts_length_tensor = torch.cuda.LongTensor(prompts_length)

    return prompts_tokens_tensor, prompts_length_tensor
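As a rough illustration of the padding scheme implemented by _tokenize_prompts_and_batch (toy token ids and a stand-in EOD value, not part of the commit): every prompt is padded with the EOD token up to the max prompt length plus tokens_to_generate, and the unpadded lengths are kept alongside.

# Toy illustration of the padding performed by _tokenize_prompts_and_batch.
# The token ids and EOD value are stand-ins; the real tokenizer comes from
# get_tokenizer() and the command-line configuration.
eod = 0
prompts_tokens = [[5, 7, 9], [11, 13]]          # pretend tokenized prompts
tokens_to_generate = 3

prompts_length = [len(t) for t in prompts_tokens]           # [3, 2]
samples_length = max(prompts_length) + tokens_to_generate   # 6
for toks, length in zip(prompts_tokens, prompts_length):
    toks.extend([eod] * (samples_length - length))

print(prompts_tokens)   # [[5, 7, 9, 0, 0, 0], [11, 13, 0, 0, 0, 0]]
print(prompts_length)   # [3, 2]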
megatron/text_generation_server.py
@@ -12,55 +12,119 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime
 import torch
 import json
 import threading
 from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron import mpu
-from megatron.text_generation_utils import generate
+from megatron.text_generation import generate_and_post_process

 GENERATE_NUM = 0
 lock = threading.Lock()

 class MegatronGenerate(Resource):
     def __init__(self, model):
         self.model = model

     @staticmethod
     def send_do_generate():
         choice = torch.cuda.LongTensor([GENERATE_NUM])
-        torch.distributed.broadcast(choice,
-                                    mpu.get_tensor_model_parallel_src_rank(),
-                                    group=mpu.get_tensor_model_parallel_group())
+        torch.distributed.broadcast(choice, 0)

     def put(self):
         args = get_args()
-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
-        max_len = 64  # Choosing hopefully sane default.  Full sequence is slow
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()), flush=True)
         print("current time: ", datetime.datetime.now())
+
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+
         if "max_len" in request.get_json():
-            max_len = request.get_json()["max_len"]
-            if not isinstance(max_len, int):
-                return "max_len must be an integer greater than 0"
-            if max_len < 1:
-                return "max_len must be an integer greater than 0"
+            return "max_len is no longer used.  Replace with tokens_to_generate", 400
+
+        if "sentences" in request.get_json():
+            return "sentences is no longer used.  Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400

-        MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences = generate(self.model, sentences, max_len)
-        return jsonify({"sentences": resp_sentences})
+        tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        if "tokens_to_generate" in request.get_json():
+            tokens_to_generate = request.get_json()["tokens_to_generate"]
+            if not isinstance(tokens_to_generate, int):
+                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 0:
+                return "tokens_to_generate must be an integer greater than or equal to 0"
+
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"
+
+        if tokens_to_generate == 0 and not logprobs:
+            return "tokens_to_generate=0 implies logprobs should be True"
+
+        temperature = 1.0
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not (type(temperature) == int or type(temperature) == float):
+                return "temperature must be a positive number less than or equal to 100.0"
+            if not (0.0 < temperature <= 100.0):
+                return "temperature must be a positive number less than or equal to 100.0"
+
+        top_k = 0.0
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 <= top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+
+        top_p = 0.0
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if top_p > 0.0 and top_k > 0.0:
+                return "cannot set both top-k and top-p samplings."
+            if not (0 <= top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"
+
+        add_BOS = False
+        if "add_BOS" in request.get_json():
+            add_BOS = request.get_json()["add_BOS"]
+            if not isinstance(add_BOS, bool):
+                return "add_BOS must be a boolean value"

-def index():
-    return current_app.send_static_file('index.html')
+        with lock:  # Need to get lock to keep multiple threads from hitting code
+            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+            response, response_seg, response_logprobs, _ = \
+                generate_and_post_process(
+                    self.model,
+                    prompts=prompts,
+                    tokens_to_generate=tokens_to_generate,
+                    return_output_log_probs=logprobs,
+                    top_k_sampling=top_k,
+                    top_p_sampling=top_p,
+                    temperature=temperature,
+                    add_BOS=add_BOS,
+                    use_eod_token_for_early_termination=True)
+
+        return jsonify({"text": response,
+                        "segments": response_seg,
+                        "logprobs": response_logprobs})

 class MegatronServer(object):
     def __init__(self, model):
-        self.app = Flask(__name__)
-        self.app.add_url_rule('/', 'index', index)
+        self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
-
-    def run(self, url):
-        self.app.run(url, threaded=False, debug=False)
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
+
+    def run(self, url):
+        self.app.run(url, threaded=True, debug=False)
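For orientation, a request against the reworked endpoint might look like the sketch below; the host, port, and prompt are placeholders, and the snippet assumes the server has already been started on the model ranks in the usual way.

# Hypothetical client for the /api endpoint added above (PUT with a JSON body).
import requests

resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Deep learning is"],
        "tokens_to_generate": 32,
        "logprobs": True,
        "temperature": 0.9,
        "top_p": 0.9,
    },
)
print(resp.json()["text"])      # prompts plus generated continuations
print(resp.json()["logprobs"])  # per-token log-probabilities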
megatron/text_generation_utils.py deleted 100644 → 0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import copy
import json
import os
import time

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.contiguous().cuda()
    # Get the attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
     ai code at
         https://medium.com/huggingface/how-to-build-a-state-of-the-art-
              conversational-ai-with-transfer-learning-2d818ac26313 """

    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Convert to 1D
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
                                        dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] \
            = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


def pad_batch(batch, pad_id, args):

    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def tokenize_batch(sentences):
    args = get_args()
    tokenizer = get_tokenizer()
    context_tokens = [tokenizer.tokenize(s) for s in sentences]
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)
    return context_tokens_tensor, context_length_tensor


def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
    """
    Needs to be synced up with receive_generate_info
    """
    # Send the sizes of the tensors
    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1),
                  max_len]
    input_info_tensor = torch.cuda.LongTensor(input_info)
    torch.distributed.broadcast(input_info_tensor, 0)

    # Send variables to all ranks
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)


def receive_generate_info():
    """
    Needs to be synced up with send_generate_info
    """
    input_info_tensor = torch.empty(3, dtype=torch.int64,
                                    device=torch.cuda.current_device())
    torch.distributed.broadcast(input_info_tensor, 0)
    batch_size = input_info_tensor[0].item()
    seq_len = input_info_tensor[1].item()
    max_len = input_info_tensor[2].item()

    context_length_tensor = torch.empty(batch_size, dtype=torch.int64,
                                        device=torch.cuda.current_device())
    context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64,
                                        device=torch.cuda.current_device())

    # Send variables to all ranks
    torch.distributed.broadcast(context_length_tensor, 0)
    torch.distributed.broadcast(context_tokens_tensor, 0)

    return context_length_tensor, context_tokens_tensor, max_len


def synced_generate(model, context_tokens_tensor, context_length_tensor,
                    max_len):
    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids,
                                                 max_len)
    for tokens, lengths in batch_token_iterator:
        context_length += 1

    if tokens is not None:
        return tokens[:, :context_length]


def generate(model, sentences=None, max_len=0):
    model.eval()
    if torch.distributed.get_rank() == 0:
        context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
        send_generate_info(context_tokens_tensor, context_length_tensor,
                           max_len)
    else:
        context_length_tensor, context_tokens_tensor, max_len = \
            receive_generate_info()

    decode_tokens = synced_generate(model, context_tokens_tensor,
                                    context_length_tensor, max_len)

    if torch.distributed.get_rank() == 0:
        args = get_args()
        tokenizer = get_tokenizer()
        resp_sentences = []
        for i in range(decode_tokens.size(0)):
            decode_token = decode_tokens[i, :].cpu().numpy().tolist()
            resp_sentences.append(tokenizer.detokenize(decode_token))
        return resp_sentences


def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """
    This function is here to provide a matching API for a legacy task.
    This implementation hasn't been tested yet to make sure it matches.
    """
    assert False, "Implementation untested"
    args = get_args()
    args.eos_id = eos_token_id
    raw_text_len = len(context)
    resp_sentences = generate(model, [context], max_gen_length)
    return resp_sentences[0][raw_text_len:]


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

    # Hidden size changes when not using recompute, need to tell p2p_communicate
    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    args.micro_batch_size = tokens.shape[0]

    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)

    args.seq_length = orig_seq_length
    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1

        maxlen = maxlen + org_context_length

        if maxlen > (org_context_length + args.out_seq_length):
            maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length < maxlen:
            types2use = None
            if counter == 0:
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                if type_ids is not None:
                    types2use = type_ids[:, :context_length]
            else:
                tokens2use = tokens[:, context_length - 1].view(
                    batch_size, -1)
                positions2use = position_ids[:, context_length - 1].view(
                    batch_size, -1)
                if type_ids is not None:
                    types2use = type_ids[:, context_length - 1].view(
                        batch_size, -1)
            output, layer_past = forward_step(
                model, tokens2use, positions2use, attention_mask,
                layer_past=layer_past,
                get_key_value=True,
                tokentype_ids=types2use,
                forward_method_parallel_output=False)
            if mpu.is_pipeline_last_stage():
                assert output is not None
                logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs, num_samples=1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
megatron/training.py
@@ -38,6 +38,7 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
+from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
@@ -61,6 +62,7 @@ def print_datetime(string):
 def pretrain(train_valid_test_dataset_provider,
              model_provider,
+             model_type,
              forward_step_func,
              extra_args_provider=None,
              args_defaults={}):
@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider,
             train/valid/test dataset and returns `train, valid, test` datasets.
         model_provider: a function that returns a vanilla version of the
             model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
+        model_type: an enum that specifies the type of model being trained.
         forward_step_func: a function that takes a `data iterator` and `model`,
             and returns a `loss` scalar with a dictionary with key:values being
             the info we would like to monitor during training, for example
@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
+                                                               model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -189,13 +193,16 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))

-def get_model(model_provider_func):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder,
+              wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
+    args.model_type = model_type

     # Build model.
     if mpu.get_pipeline_model_parallel_world_size() > 1 and \
        args.virtual_pipeline_model_parallel_size is not None:
+        assert model_type != ModelType.encoder_and_decoder, \
+            "Interleaved schedule not supported for model with both encoder and decoder"
         model = []
         for i in range(args.virtual_pipeline_model_parallel_size):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
@@ -206,14 +213,36 @@ def get_model(model_provider_func):
                 pre_process=pre_process,
                 post_process=post_process)
+            this_model.model_type = model_type
             model.append(this_model)
     else:
         pre_process = mpu.is_pipeline_first_stage()
         post_process = mpu.is_pipeline_last_stage()
-        model = model_provider_func(
-            pre_process=pre_process,
-            post_process=post_process)
+        add_encoder = True
+        add_decoder = True
+        if model_type == ModelType.encoder_and_decoder:
+            if mpu.get_pipeline_model_parallel_world_size() > 1:
+                assert args.pipeline_model_parallel_split_rank is not None, \
+                    "Split rank needs to be specified for model with both encoder and decoder"
+                rank = mpu.get_pipeline_model_parallel_rank()
+                split_rank = args.pipeline_model_parallel_split_rank
+                world_size = mpu.get_pipeline_model_parallel_world_size()
+                pre_process = rank == 0 or rank == split_rank
+                post_process = (rank == (split_rank - 1)) or (
+                        rank == (world_size - 1))
+                add_encoder = mpu.is_pipeline_stage_before_split()
+                add_decoder = mpu.is_pipeline_stage_after_split()
+            model = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process,
+                add_encoder=add_encoder,
+                add_decoder=add_decoder)
+        else:
+            model = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process)
+        model.model_type = model_type

     if not isinstance(model, list):
         model = [model]
@@ -243,22 +272,24 @@ def get_model(model_provider_func):
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]

-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        model = [torchDDP(model_module, device_ids=[i], output_device=i,
-                          process_group=mpu.get_data_parallel_group())
-                 for model_module in model]
-        return model
+    if wrap_with_ddp:
+        if args.DDP_impl == 'torch':
+            i = torch.cuda.current_device()
+            model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+                     for model_module in model]

-    if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module,
-                          args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_local_ddp)
-                 for model_module in model]
-        return model
+        elif args.DDP_impl == 'local':
+            model = [LocalDDP(model_module,
+                              args.accumulate_allreduce_grads_in_fp32,
+                              args.use_contiguous_buffers_in_local_ddp)
+                     for model_module in model]

-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+        else:
+            raise NotImplementedError('Unknown DDP implementation specified: '
+                                      '{}. Exiting.'.format(args.DDP_impl))

     return model


 def get_learning_rate_scheduler(optimizer):
@@ -304,11 +335,11 @@ def get_learning_rate_scheduler(optimizer):
     return lr_scheduler

-def setup_model_and_optimizer(model_provider_func):
+def setup_model_and_optimizer(model_provider_func, model_type):
     """Setup model and optimizer."""
     args = get_args()

-    model = get_model(model_provider_func)
+    model = get_model(model_provider_func, model_type)

     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
@@ -377,13 +408,14 @@ def train_step(forward_step_func, data_iterator,
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
-    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
-        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
+    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
         if mpu.is_pipeline_first_stage(ignore_virtual=True):
             unwrapped_model = model[0]
         elif mpu.is_pipeline_last_stage(ignore_virtual=True):
             unwrapped_model = model[-1]
+        else:  # We do not support the interleaved schedule for T5 yet.
+            unwrapped_model = model[0]
         unwrapped_model = unwrap_model(
             unwrapped_model, (torchDDP, LocalDDP, Float16Module))
pretrain_bert.py
@@ -25,7 +25,7 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel
+from megatron.model import BertModel, ModelType
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    pretrain(train_valid_test_datasets_provider, model_provider,
+             ModelType.encoder_or_decoder,
+             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_gpt.py
@@ -23,7 +23,7 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel
+from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+    pretrain(train_valid_test_datasets_provider, model_provider,
+             ModelType.encoder_or_decoder,
+             forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
pretrain_ict.py
@@ -28,6 +28,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
+from megatron.model import ModelType
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider,
              pretrain_ict_model_provider,
+             ModelType.encoder_or_decoder,
              forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain_t5.py
@@ -26,18 +26,58 @@ from megatron import (
     print_rank_0
 )
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import T5Model
+from megatron.model import T5Model, ModelType
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider(pre_process=True, post_process=True):
+"""
+Pipeline parallelism for T5
+===========================
+
+T5 is a model architecture with both encoder and decoder blocks.
+Consequently, pipeline parallelism is implemented slightly differently
+compared to architectures like GPT and BERT.
+
+In particular, when pipeline_model_parallel_world_size > 1, each stage
+either executes an encoder block or a decoder block. The
+--pipeline-model-parallel-split-rank argument controls the rank at which
+the split happens: all ranks lower than this argument execute the
+encoder block, and all ranks equal to or higher than this argument value
+execute the decoder block.
+
+In the encoder section of the model, only one tensor is sent downstream:
+the intermediate encoder_hidden_state. In the decoder section of the
+model, two tensors are sent downstream in the forward pass: the fully
+computed encoder_hidden_state, and the intermediate decoder_hidden_state.
+
+In particular, these are the shapes of the tensors sent between
+different workers:
+    If rank is in decoder section:
+        intermediate decoder_hidden_state (pre-transpose),
+        complete encoder_hidden_state (post-transpose).
+    If rank is at boundary between encoder and decoder sections:
+        complete encoder_hidden_state (post-transpose).
+    If rank is in encoder section:
+        intermediate encoder_hidden_state (pre-transpose).
+
+Additionally, we have code in the backward_step function in schedules.py
+to accumulate the encoder_hidden_state gradient across skip connections
+(encoder_hidden_state fed in as input to each layer in the decoder).
+"""
+
+
+def model_provider(pre_process=True, post_process=True,
+                   add_encoder=True, add_decoder=True):
     """Build the model."""

-    assert pre_process and post_process, "T5 doesn't yet support pipelining"
-
     print_rank_0('building T5 model ...')
     model = T5Model(num_tokentypes=0,
-                    parallel_output=True)
+                    parallel_output=True,
+                    pre_process=pre_process,
+                    post_process=post_process,
+                    add_encoder=add_encoder,
+                    add_decoder=add_decoder)
     return model
@@ -70,9 +110,7 @@ def get_batch(data_iterator):
 def loss_func(loss_mask, output_tensor):
-    lm_loss_, _ = output_tensor
-
-    lm_loss_ = lm_loss_.float()
+    lm_loss_ = output_tensor.float()
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
@@ -130,5 +168,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
-             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    pretrain(train_valid_test_datasets_provider, model_provider,
+             ModelType.encoder_and_decoder,
+             forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
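To make the split-rank rule described in the new comment block concrete, a toy mapping is sketched below; the values are assumptions, and in the real code the decision is made by the mpu helpers such as is_pipeline_stage_before_split().

# Toy illustration of --pipeline-model-parallel-split-rank:
# ranks < split_rank run encoder blocks, ranks >= split_rank run decoder blocks.
pipeline_world_size = 4
split_rank = 2

for rank in range(pipeline_world_size):
    section = "encoder" if rank < split_rank else "decoder"
    print(f"pipeline rank {rank}: {section}")
# pipeline rank 0: encoder
# pipeline rank 1: encoder
# pipeline rank 2: decoder
# pipeline rank 3: decoder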
pretrain_vit.py
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
+from megatron.model import ModelType
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
@@ -92,6 +93,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'dataloader_type': 'cyclic'}
     )