OpenDAS / Megatron-LM · Commits

Commit 2eea6216
Authored Jul 18, 2022 by rprenger

    Merging with main and fixing merge conflict

Parents: ed6806ac, 5f694372

Changes: 63 files changed in this commit; this page shows 20 changed files, with 1931 additions and 437 deletions (+1931 −437).
megatron/model/t5_model.py                 +7    −2
megatron/model/transformer.py              +261  −189
megatron/model/utils.py                    +2    −1
megatron/model/vision/vit_backbone.py      +3    −1
megatron/mpu/__init__.py                   +9    −6
megatron/mpu/initialize.py                 +12   −5
megatron/mpu/layers.py                     +122  −40
megatron/mpu/mappings.py                   +141  −10
megatron/mpu/random.py                     +7    −7
megatron/optimizer/__init__.py             +22   −10
megatron/optimizer/clip_grads.py           +27   −28
megatron/optimizer/distrib_optimizer.py    +696  −0
megatron/optimizer/optimizer.py            +337  −121
megatron/p2p_communication.py              +6    −3
megatron/schedules.py                      +21   −10
megatron/text_generation/__init__.py       +2    −1
megatron/text_generation/api.py            +53   −1
megatron/text_generation/beam_utils.py     +64   −0
megatron/text_generation/forward_step.py   +12   −1
megatron/text_generation/generation.py     +127  −1
megatron/model/t5_model.py (+7 −2)

@@ -152,19 +152,24 @@ class T5Model(MegatronModule):
         if self.post_process and self.add_decoder:
             decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
             lm_logits = self.lm_head(decoder_output,
                                      self.word_embeddings_weight())

             if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0, 1).contiguous()
             else:
+                # [b s] => [s b]
+                lm_labels = lm_labels.transpose(0, 1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                 else:
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0, 1).contiguous()
                 return lm_loss
         elif self.add_decoder and not self.add_encoder:
             decoder_output, encoder_output = lm_output
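Note: the hunk above moves logits and labels between the sequence-first [s, b, h] layout used inside the model and the batch-first [b, s, h] layout returned to the caller via transpose(0, 1).contiguous(). A minimal illustrative sketch of that invariant (plain PyTorch, not part of the diff; sizes are made up):

    import torch

    s, b, h = 4, 2, 8                        # sequence length, micro-batch, hidden
    lm_logits = torch.randn(s, b, h)         # model-internal layout: [s, b, h]

    # [s b h] => [b s h], as done before returning logits to the caller.
    batch_first = lm_logits.transpose(0, 1).contiguous()
    assert batch_first.shape == (b, s, h)

    # Round-tripping the transpose recovers the original layout.
    assert torch.equal(batch_first.transpose(0, 1).contiguous(), lm_logits)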
megatron/model/transformer.py (+261 −189)

This diff is collapsed.
megatron/model/utils.py (+2 −1)

@@ -47,7 +47,8 @@ def attention_mask_func(attention_scores, attention_mask):
 def get_linear_layer(rows, columns, init_method):
     """Simple linear layer with weight initialization."""
     layer = torch.nn.Linear(rows, columns)
-    init_method(layer.weight)
+    if get_args().perform_initialization:
+        init_method(layer.weight)
     with torch.no_grad():
         layer.bias.zero_()
     return layer
megatron/model/vision/vit_backbone.py (+3 −1)

@@ -21,7 +21,6 @@ import torch
 import apex
 import torch.nn.functional as F
 from megatron import get_args
-from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,

@@ -148,6 +147,7 @@ class VitBackbone(MegatronModule):
                  post_process=True,
                  class_token=True,
                  single_token_output=False,
+                 post_layer_norm=True,
                  drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()

@@ -165,6 +165,7 @@ class VitBackbone(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.class_token = class_token
+        self.post_layer_norm = post_layer_norm
         self.hidden_size = args.hidden_size
         self.patch_dim = args.patch_dim
         self.img_h = args.img_h

@@ -218,6 +219,7 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            post_layer_norm=self.post_layer_norm,
             drop_path_rate=self.drop_path_rate
         )
megatron/mpu/__init__.py (+9 −6)

@@ -49,18 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

-from .layers import LinearWithGradAccumulationAndAsyncAllreduce
+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
 from .layers import (set_tensor_model_parallel_attributes,
                      set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)

-from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
-from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import scatter_to_sequence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region

 from .random import checkpoint
 from .random import get_cuda_rng_tracker
megatron/mpu/initialize.py (+12 −5)

@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None

+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
+
 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None

@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []

@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks

     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP

@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]

 def get_pipeline_model_parallel_first_rank():
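Note: get_data_parallel_src_rank() now reads the first entry of the cached _DATA_PARALLEL_GLOBAL_RANKS list instead of recomputing the source rank from the global rank. A rough sketch of how such rank lists can be formed (pure Python, no torch.distributed; the grouping below assumes tensor-parallel-major rank ordering and no pipeline parallelism, which is a simplification rather than the exact Megatron layout):

    # world_size = 8 GPUs, tensor_model_parallel_size = 2, no pipeline parallelism.
    world_size = 8
    tensor_model_parallel_size = 2

    data_parallel_groups = []
    for i in range(tensor_model_parallel_size):
        # Ranks holding the same tensor-parallel slice form one data-parallel group.
        data_parallel_groups.append(list(range(i, world_size, tensor_model_parallel_size)))

    print(data_parallel_groups)      # [[0, 2, 4, 6], [1, 3, 5, 7]]
    # Rank 3 would cache [1, 3, 5, 7]; its data-parallel broadcast source is ranks[0] == 1.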
megatron/mpu/layers.py (+122 −40)

@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
-from megatron import get_args
+from megatron import get_args, get_global_memory_buffer

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (

@@ -166,15 +167,17 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.weight = Parameter(torch.empty(
                 self.num_embeddings_per_partition, self.embedding_dim,
                 dtype=args.params_dtype))
-            _initialize_affine_weight_cpu(
-                self.weight, self.num_embeddings, self.embedding_dim,
-                self.num_embeddings_per_partition, 0, init_method)
+            if args.perform_initialization:
+                _initialize_affine_weight_cpu(
+                    self.weight, self.num_embeddings, self.embedding_dim,
+                    self.num_embeddings_per_partition, 0, init_method)
         else:
             self.weight = Parameter(torch.empty(
                 self.num_embeddings_per_partition, self.embedding_dim,
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
-            _initialize_affine_weight_gpu(self.weight, init_method,
-                                          partition_dim=0, stride=1)
+            if args.perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=0, stride=1)

     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:

@@ -199,35 +202,75 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output

-class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """
-    Linear layer execution with asynchronous all-reduce and gradient accumulation
+    Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """

     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce):
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
         ctx.async_grad_allreduce = async_grad_allreduce
-        output = torch.matmul(input, weight.t())
+        ctx.sequence_parallel = sequence_parallel
+
+        if sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group())
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
+        output = torch.matmul(total_input, weight.t())
         if bias is not None:
             output = output + bias
         return output

     @staticmethod
     def backward(ctx, grad_output):
-        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
+
+        if ctx.sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+
+            # Delay the start of intput gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = all_gather_buffer
+        else:
+            total_input = input
         grad_input = grad_output.matmul(weight)
+
+        if ctx.sequence_parallel:
+            handle.wait()
+
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                        grad_output.shape[2])
-        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
+        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                       total_input.shape[2])

         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = torch.distributed.all_reduce(

@@ -235,15 +278,38 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             # Delay the start of weight gradient computation shortly (3us) to have
             # all-reduce scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1

+        if ctx.sequence_parallel:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(
+                sub_grad_input, grad_input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
         if ctx.gradient_accumulation_fusion:
-            fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
+            import fused_dense_cuda
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
             grad_weight = None
         else:
-            grad_weight = grad_output.t().matmul(input)
+            grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        if ctx.sequence_parallel:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
+
         if ctx.async_grad_allreduce:
             handle.wait()
-        return grad_input, grad_weight, grad_bias, None, None
+
+        return grad_input, grad_weight, grad_bias, None, None, None

 class ColumnParallelLinear(torch.nn.Module):

@@ -266,7 +332,7 @@ class ColumnParallelLinear(torch.nn.Module):
                        set to False. It returns the master weights
                        used for initialization.
     skip_bias_add: This was added to enable performance optimations where bias
                    can be fused with other elementwise operations. we skip
                    adding bias but instead return it.
     """

@@ -294,16 +360,18 @@ class ColumnParallelLinear(torch.nn.Module):
             self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                 self.input_size,
                                                 dtype=args.params_dtype))
-            self.master_weight = _initialize_affine_weight_cpu(
-                self.weight, self.output_size, self.input_size,
-                self.output_size_per_partition, 0, init_method,
-                stride=stride, return_master_weight=keep_master_weight_for_test)
+            if args.perform_initialization:
+                self.master_weight = _initialize_affine_weight_cpu(
+                    self.weight, self.output_size, self.input_size,
+                    self.output_size_per_partition, 0, init_method,
+                    stride=stride, return_master_weight=keep_master_weight_for_test)
         else:
             self.weight = Parameter(torch.empty(
                 self.output_size_per_partition, self.input_size,
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
-            _initialize_affine_weight_gpu(self.weight, init_method,
-                                          partition_dim=0, stride=stride)
+            if args.perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=0, stride=stride)

         if bias:
             if args.use_cpu_initialization:

@@ -323,23 +391,28 @@ class ColumnParallelLinear(torch.nn.Module):
         self.async_tensor_model_parallel_allreduce = (
             args.async_tensor_model_parallel_allreduce and
             world_size > 1)
+        self.sequence_parallel = (
+            args.sequence_parallel and
+            world_size > 1)
+        assert not self.async_tensor_model_parallel_allreduce or \
+            not self.sequence_parallel
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
+        if self.async_tensor_model_parallel_allreduce or \
+                self.sequence_parallel:
             input_parallel = input_
         else:
             # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
             input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-            self.async_tensor_model_parallel_allreduce)
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel

@@ -402,16 +475,18 @@ class RowParallelLinear(torch.nn.Module):
             self.weight = Parameter(torch.empty(self.output_size,
                                                 self.input_size_per_partition,
                                                 dtype=args.params_dtype))
-            self.master_weight = _initialize_affine_weight_cpu(
-                self.weight, self.output_size, self.input_size,
-                self.input_size_per_partition, 1, init_method,
-                stride=stride, return_master_weight=keep_master_weight_for_test)
+            if args.perform_initialization:
+                self.master_weight = _initialize_affine_weight_cpu(
+                    self.weight, self.output_size, self.input_size,
+                    self.input_size_per_partition, 1, init_method,
+                    stride=stride, return_master_weight=keep_master_weight_for_test)
         else:
             self.weight = Parameter(torch.empty(
                 self.output_size, self.input_size_per_partition,
                 device=torch.cuda.current_device(), dtype=args.params_dtype))
-            _initialize_affine_weight_gpu(self.weight, init_method,
-                                          partition_dim=1, stride=stride)
+            if args.perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=1, stride=stride)
         if bias:
             if args.use_cpu_initialization:
                 self.bias = Parameter(torch.empty(self.output_size,

@@ -420,26 +495,34 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias = Parameter(torch.empty(
                     self.output_size, device=torch.cuda.current_device(),
                     dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
+
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.sequence_parallel = args.sequence_parallel
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         # Set up backprop all-reduce.
         if self.input_is_parallel:
             input_parallel = input_
         else:
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
-            input_parallel, self.weight, None,
-            self.gradient_accumulation_fusion, None)
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.sequence_parallel:
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None

@@ -447,4 +530,3 @@ class RowParallelLinear(torch.nn.Module):
             output = output_
             output_bias = self.bias
         return output, output_bias
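Note: the backward pass above flattens the (sequence, batch) leading dimensions before the weight-gradient GEMM, so the same grad_output.t().matmul(total_input) expression works whether total_input is the local activation or the sequence-parallel all-gathered buffer. A CPU-only sketch of that shape algebra (illustrative sizes, no distributed calls, not part of the diff):

    import torch

    s, b, h_in, h_out = 6, 2, 4, 3
    total_input = torch.randn(s, b, h_in)
    weight = torch.randn(h_out, h_in)

    output = torch.matmul(total_input, weight.t())           # [s, b, h_out]
    grad_output = torch.randn_like(output)

    # Collapse (s, b) into one dimension for the weight-gradient GEMM.
    go2d = grad_output.view(s * b, h_out)
    in2d = total_input.view(s * b, h_in)
    grad_weight = go2d.t().matmul(in2d)                      # [h_out, h_in]
    assert grad_weight.shape == weight.shape

    # Matches what autograd produces for the same matmul.
    weight_ref = weight.clone().requires_grad_(True)
    torch.matmul(total_input, weight_ref.t()).backward(grad_output)
    assert torch.allclose(grad_weight, weight_ref.grad, atol=1e-5)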
megatron/mpu/mappings.py (+141 −10)

@@ -32,13 +32,13 @@ def _reduce(input_):
     return input_

-def _split(input_):
+def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Split along last dimension.

@@ -51,12 +51,34 @@ def _split_along_last_dim(input_):
     return output

-def _gather(input_):
+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
+
+    return output
+
+def _gather_along_last_dim(input_):
     """Gather tensors and concatinate along the last dimension."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Size and dimension.

@@ -73,6 +95,44 @@ def _gather_along_last_dim(input_):
     return output

+def _gather_along_first_dim(input_):
+    """Gather tensors and concatinate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
+                                           group=get_tensor_model_parallel_group())
+    return output
+
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""

@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)

 class _GatherFromModelParallelRegion(torch.autograd.Function):

@@ -126,15 +186,73 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _split_along_last_dim(grad_output)

+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chuck to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatinate."""
+
+    @staticmethod
+    def symbolic(graph, input_, tensor_parallel_output_grad=True):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_, tensor_parallel_output_grad=True):
+        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
+
+        # If the computation graph after the gather operation is
+        # in the tensor parallel mode, output gradients need to reduce
+        # scattered and whereas if the computation is duplicated,
+        # output gradients need to be scattered.
+        if tensor_parallel_output_grad:
+            return _reduce_scatter_along_first_dim(grad_output), None
+        else:
+            return _split_along_first_dim(grad_output), None
+
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)

 # -----------------

@@ -155,3 +273,16 @@ def scatter_to_tensor_model_parallel_region(input_):

 def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)

+def scatter_to_sequence_parallel_region(input_):
+    return _ScatterToSequenceParallelRegion.apply(input_)
+
+def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
+    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
+
+def reduce_scatter_to_sequence_parallel_region(input_):
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
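Note: the new _split_along_first_dim/_gather_along_first_dim helpers slice and re-concatenate activations on the sequence (first) dimension. A single-process sketch of the indexing they rely on (no process group; world_size is just a loop bound here, not a real communicator):

    import torch

    world_size = 4
    x = torch.arange(8 * 3, dtype=torch.float32).view(8, 3)   # first dim divisible by world_size

    dim_size = x.size(0)
    assert dim_size % world_size == 0
    local_dim_size = dim_size // world_size

    # What each rank would keep after _split_along_first_dim.
    shards = [x[r * local_dim_size:(r + 1) * local_dim_size].contiguous()
              for r in range(world_size)]

    # Concatenating every rank's shard (what the all-gather rebuilds) restores x.
    assert torch.equal(torch.cat(shards, dim=0), x)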
megatron/mpu/random.py (+7 −7)

@@ -307,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
     tracked/set/reset.
     """
     @staticmethod
-    def forward(ctx, run_function, distribute_checkpointed_activations, *args):
+    def forward(ctx, run_function, distribute_saved_activations, *args):
         ctx.run_function = run_function
-        ctx.distribute_checkpointed_activations \
-            = distribute_checkpointed_activations
+        ctx.distribute_saved_activations \
+            = distribute_saved_activations

         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()

@@ -322,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
         # Divide hidden states across model parallel group and only keep
         # the chunk corresponding to the current rank.
-        if distribute_checkpointed_activations:
+        if distribute_saved_activations:
             ctx.input_0_shape = args[0].data.shape
             safely_set_viewless_tensor_data(
                 args[0],

@@ -339,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), "
                                "please use .backward() if possible")
         inputs = ctx.saved_tensors
-        if ctx.distribute_checkpointed_activations:
+        if ctx.distribute_saved_activations:
             safely_set_viewless_tensor_data(
                 inputs[0],
                 gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

@@ -372,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
         return (None, None) + grads

-def checkpoint(function, distribute_checkpointed_activations, *args):
+def checkpoint(function, distribute_saved_activations, *args):
     """Checkpoint a model or part of the model.
     This has been directly copied from torch.utils.checkpoint."""
     return CheckpointFunction.apply(function,
-                                    distribute_checkpointed_activations, *args)
+                                    distribute_saved_activations, *args)
megatron/optimizer/__init__.py (+22 −10)

@@ -17,8 +17,8 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import LayerNorm

+from .distrib_optimizer import DistributedOptimizer
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer

@@ -105,7 +105,11 @@ def get_megatron_optimizer(model,
     if args.DDP_impl == 'local':
         params_have_main_grad = True

-    if args.fp16 or args.bf16:
+    # Mixed precision optimizer.
+    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
+    #   from the MixedPrecisionOptimizer, which manages any optimizer where
+    #   the model params and main params are distinct.
+    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

         # Grad scaler:
         #    if loss-scale is provided, instantiate the constant scaler.

@@ -114,9 +118,11 @@ def get_megatron_optimizer(model,
         #    otherwise we are running in bf16 with no loss-scale so
         #    leave it as None.
         grad_scaler = None
+
+        # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
+
+        # Dynamic loss scale.
         else:
             if args.fp16:

@@ -129,16 +135,22 @@ def get_megatron_optimizer(model,
                     hysteresis=args.hysteresis)

         # Megatron optimizer.
-        return Float16OptimizerWithFloat16Params(optimizer,
-                                                 args.clip_grad,
-                                                 args.log_num_zeros_in_grad,
-                                                 params_have_main_grad,
-                                                 args.use_contiguous_buffers_in_local_ddp,
-                                                 args.bf16,
-                                                 grad_scaler)
+        opt_ty = DistributedOptimizer \
+            if args.use_distributed_optimizer else \
+            Float16OptimizerWithFloat16Params
+        return opt_ty(optimizer,
+                      args.clip_grad,
+                      args.log_num_zeros_in_grad,
+                      params_have_main_grad,
+                      args.use_contiguous_buffers_in_local_ddp,
+                      args.fp16,
+                      args.bf16,
+                      grad_scaler,
+                      model)

     # FP32.
     return FP32Optimizer(optimizer, args.clip_grad,
                          args.log_num_zeros_in_grad,
                          params_have_main_grad,
-                         args.use_contiguous_buffers_in_local_ddp)
+                         args.use_contiguous_buffers_in_local_ddp,
+                         model)
megatron/optimizer/clip_grads.py (+27 −28)

@@ -21,12 +21,13 @@ from torch._six import inf
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import mpu
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+def clip_grad_norm_fp32(parameters, grads_for_norm,
+                        max_norm, norm_type=2,
+                        model_parallel_group=None):
     """Clips gradient norm of an iterable of parameters whose gradients
        are in fp32.

@@ -37,9 +38,13 @@ def clip_grad_norm_fp32(parameters, grads_for_norm, max_norm, norm_type=2, model_parallel_group=None):
     Arguments:
         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
             single Tensor that will have gradients normalized
+        grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
+            Tensor that will be used for calculating the grad norm.
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
+        model_parallel_group (group): given the nature of the distributed
+            optimizer, this is passed as an argument.

     Returns:
         Total norm of the parameters (viewed as a single vector).

@@ -47,25 +52,15 @@
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    if isinstance(grads_for_norm, torch.Tensor):
+        grads_for_norm = [grads_for_norm]

-    # Filter parameters based on:
-    #   - grad should not be none
-    #   - parameter should not be shared
-    #   - should not be a replica due to tensor model parallelism
+    # Grads.
     grads = []
-    grads_for_norm = []
     for param in parameters:
-        grad_not_none = param.grad is not None
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if grad_not_none:
-            grad = param.grad.detach()
-        if grad_not_none:
+        if param.grad is not None:
             # Make sure the grads are in fp32
             assert param.grad.type() == 'torch.cuda.FloatTensor'
-            grads.append(grad)
-        if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            grads.append(param.grad.detach())

     # Norm parameters.
     max_norm = float(max_norm)

@@ -79,7 +74,7 @@
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
         total_norm = total_norm_cuda[0].item()

     else:

@@ -88,12 +83,15 @@
         # Use apex's multi-tensor applier for efficiency reasons.
         # Multi-tensor applier takes a function and a list of list
         # and performs the operation on that list all in one kernel.
-        grad_norm, _ = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm,
-            dummy_overflow_buf,
-            [grads_for_norm],
-            False # no per-parameter norm
-        )
+        if grads_for_norm:
+            grad_norm, _ = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm,
+                dummy_overflow_buf,
+                [grads_for_norm],
+                False # no per-parameter norm
+            )
+        else:
+            grad_norm = torch.cuda.FloatTensor([0])
         # Since we will be summing across data parallel groups,
         # we need the pow(norm-type).
         total_norm = grad_norm ** norm_type

@@ -106,7 +104,7 @@
         # Sum across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
         total_norm = total_norm.item() ** (1.0 / norm_type)

     # Scale.

@@ -121,7 +119,7 @@
     return total_norm

-def count_zeros_fp32(parameters):
+def count_zeros_fp32(parameters, model_parallel_group):

     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]

@@ -130,7 +128,7 @@ def count_zeros_fp32(parameters, model_parallel_group):
     #   - grad should not be none
     #   - parameter should not be shared
     #   - should not be a replica due to tensor model parallelism
-    total_num_zeros = 0.0
+    total_num_zeros = torch.cuda.FloatTensor([0.0])
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)

@@ -143,7 +141,8 @@ def count_zeros_fp32(parameters, model_parallel_group):
     # Sum across all model-parallel GPUs.
     torch.distributed.all_reduce(total_num_zeros,
                                  op=torch.distributed.ReduceOp.SUM,
-                                 group=mpu.get_model_parallel_group())
+                                 group=model_parallel_group)
+
     total_num_zeros = total_num_zeros.item()

     return total_num_zeros
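Note: clip_grad_norm_fp32 now takes grads_for_norm explicitly and reduces over a caller-supplied model_parallel_group, but the underlying arithmetic is unchanged: accumulate |g|^norm_type over the selected grads, take the norm_type-th root, then scale every grad by max_norm / total_norm when the norm is too large. A local sketch of that arithmetic (plain PyTorch, no apex and no distributed reduction; the 1.0e-6 guard against division by zero is an assumption of this sketch):

    import torch

    grads = [torch.randn(10), torch.randn(5)]
    grads_for_norm = grads        # in Megatron this excludes shared and tensor-parallel-duplicate params
    max_norm, norm_type = 1.0, 2

    total_norm = sum(g.norm(norm_type) ** norm_type for g in grads_for_norm)
    total_norm = total_norm.item() ** (1.0 / norm_type)

    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)

    # After clipping, the global norm is at most max_norm (up to the epsilon).
    clipped = sum(g.norm(norm_type) ** norm_type for g in grads) ** (1.0 / norm_type)
    assert clipped <= max_norm + 1e-4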
megatron/optimizer/distrib_optimizer.py (new file, +696 −0)

This diff is collapsed.
megatron/optimizer/optimizer.py (+337 −121)

This diff is collapsed.
megatron/p2p_communication.py (+6 −3)

@@ -61,7 +61,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)

     override_scatter_gather_tensors_in_pipeline = False
-    if args.scatter_gather_tensors_in_pipeline:
+    if args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
         if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
             tensor_chunk_shape = tensor_chunk_shape // \

@@ -93,7 +94,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Split tensor into smaller chunks if using scatter-gather optimization.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

@@ -138,7 +140,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
-            args.scatter_gather_tensors_in_pipeline:
+            args.scatter_gather_tensors_in_pipeline and \
+            not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
megatron/schedules.py (+21 −10)

@@ -279,8 +279,12 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
     pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

     args = get_args()
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+    tensor_shape = (seq_length, args.micro_batch_size, args.hidden_size)

     # Compute number of warmup and remaining microbatches.
     num_model_chunks = len(model)
     num_microbatches = get_num_microbatches() * num_model_chunks

@@ -514,18 +518,25 @@ def get_tensor_shapes(rank, model_type):
     # Otherwise, send one tensor (pre-transpose).
     args = get_args()
     tensor_shapes = []
+
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+
     if model_type == ModelType.encoder_and_decoder:
+        if args.sequence_parallel:
+            decoder_seq_length = args.decoder_seq_length // mpu.get_tensor_model_parallel_world_size()
+        else:
+            decoder_seq_length = args.decoder_seq_length
+
         if mpu.is_pipeline_stage_before_split(rank):
             # If next rank is after split, then need transpose for encoder_hidden_state.
             if mpu.is_pipeline_stage_before_split(rank+1):
-                tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
             else:
-                tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+                tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
         else:
-            tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
-            tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
+            tensor_shapes.append((decoder_seq_length, args.micro_batch_size, args.hidden_size))
+            tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     else:
-        tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
+        tensor_shapes.append((seq_length, args.micro_batch_size, args.hidden_size))
     return tensor_shapes
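Note: with sequence parallelism enabled, the activations exchanged between pipeline stages are already split along the sequence dimension, so the communicated shape uses seq_length // tensor-parallel world size. A tiny sketch of that shape computation (pure Python; the argument values are illustrative, and the helper name is invented for this sketch):

    def p2p_tensor_shape(seq_length, micro_batch_size, hidden_size,
                         sequence_parallel, tp_world_size):
        # Mirrors the seq_length adjustment in get_tensor_shapes() above.
        if sequence_parallel:
            seq_length = seq_length // tp_world_size
        return (seq_length, micro_batch_size, hidden_size)

    assert p2p_tensor_shape(2048, 4, 1024, False, 8) == (2048, 4, 1024)
    assert p2p_tensor_shape(2048, 4, 1024, True, 8) == (256, 4, 1024)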
megatron/text_generation/__init__.py (+2 −1)

@@ -16,4 +16,5 @@
 from .api import (
     generate,
-    generate_and_post_process)
+    generate_and_post_process,
+    beam_search_and_post_process)
megatron/text_generation/api.py (+53 −1)

@@ -22,7 +22,8 @@ from megatron import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
-    score_and_return_on_first_stage)
+    score_and_return_on_first_stage,
+    beam_search_and_return_on_first_stage)
 from .tokenization import (
     tokenize_prompts,
     detokenize_generations)

@@ -148,3 +149,54 @@ def generate(model,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
         stop_on_eol=stop_on_eol)

+def beam_search_and_post_process(model,
+                                 prompts=None,
+                                 tokens_to_generate=0,
+                                 beam_size=0,
+                                 add_BOS=False,
+                                 stop_token=50256,
+                                 num_return_gen=1,
+                                 length_penalty=1):
+    """Run beam search and post-process outputs, i.e., detokenize,
+    move to cpu and convert to list."""
+
+    # Main inference.
+    tokens, scores = beam_search(model,
+                                 prompts=prompts,
+                                 tokens_to_generate=tokens_to_generate,
+                                 beam_size=beam_size,
+                                 add_BOS=add_BOS,
+                                 stop_token=stop_token,
+                                 num_return_gen=num_return_gen,
+                                 length_penalty=length_penalty)
+    # Only post-process on first stage.
+    if mpu.is_pipeline_first_stage():
+        lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64,
+                                            device=torch.cuda.current_device())
+        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
+            detokenize_generations(tokens, lengths, True)
+        scores = scores.cpu().numpy().tolist()
+        return prompts_plus_generations, prompts_plus_generations_segments, scores
+
+    return None
+
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0,
+                add_BOS=False, stop_token=50256, num_return_gen=1,
+                length_penalty=1):
+    # Make sure input params are avaialble to all ranks.
+    values = [tokens_to_generate,
+              beam_size,
+              add_BOS,
+              stop_token,
+              num_return_gen,
+              length_penalty]
+    values_float_tensor = broadcast_float_list(6, float_list=values)
+    tokens_to_generate = int(values_float_tensor[0].item())
+    beam_size = int(values_float_tensor[1].item())
+    add_BOS = bool(values_float_tensor[2].item())
+    stop_token = int(values_float_tensor[3].item())
+    num_return_gen = int(values_float_tensor[4].item())
+    length_penalty = values_float_tensor[5].item()
+
+    context_tokens_tensor, context_length_tensor = tokenize_prompts(
+        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
+
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor,
+                                                 context_length_tensor,
+                                                 beam_size,
+                                                 stop_token=stop_token,
+                                                 num_return_gen=num_return_gen,
+                                                 length_penalty=length_penalty)
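Note: beam_search packs its scalar arguments into a single float tensor so one broadcast_float_list call can ship them to every rank, then casts them back on the receiving side. A local sketch of that pack/unpack round trip (plain PyTorch; the broadcast itself is replaced by a plain tensor here since no process group is involved):

    import torch

    tokens_to_generate, beam_size, add_BOS = 64, 4, False
    stop_token, num_return_gen, length_penalty = 50256, 1, 1.0

    values = [tokens_to_generate, beam_size, add_BOS,
              stop_token, num_return_gen, length_penalty]
    values_float_tensor = torch.tensor(values, dtype=torch.float32)  # stand-in for broadcast_float_list(6, ...)

    assert int(values_float_tensor[0].item()) == 64
    assert int(values_float_tensor[1].item()) == 4
    assert bool(values_float_tensor[2].item()) is False
    assert int(values_float_tensor[3].item()) == 50256
    assert int(values_float_tensor[4].item()) == 1
    assert values_float_tensor[5].item() == 1.0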
megatron/text_generation/beam_utils.py (new file, +64 −0)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## from huggingface beam search
class BeamHypotheses(object):
    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs, length):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / length ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
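Note: a short usage sketch for the new BeamHypotheses container (assumes Megatron-LM is on the Python path; the token tensors are stand-ins):

    import torch
    from megatron.text_generation.beam_utils import BeamHypotheses

    beam_hyp = BeamHypotheses(num_beams=2, length_penalty=1.0)

    # Score is sum_logprobs / length ** length_penalty; higher is better.
    beam_hyp.add(torch.tensor([1, 2, 3]), sum_logprobs=-2.0, length=3)
    beam_hyp.add(torch.tensor([1, 2, 4]), sum_logprobs=-1.0, length=3)
    beam_hyp.add(torch.tensor([1, 2, 5]), sum_logprobs=-6.0, length=3)  # rejected: worse than the current worst

    assert len(beam_hyp) == 2
    best_score, best_tokens = max(beam_hyp.beams, key=lambda x: x[0])
    assert best_score == -1.0 / 3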
megatron/text_generation/forward_step.py (+12 −1)

@@ -42,7 +42,18 @@ class InferenceParams:
         self.batch_size_offset = 0
         self.key_value_memory_dict = {}

+    def swap_key_value_dict(self, batch_idx):
+        "swap between batches"
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError("should not swap when dict in empty")
+
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+            assert len(batch_idx) == inference_key_memory.shape[1]  ## make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (
+                new_inference_key_memory, new_inference_value_memory)

 class ForwardStep:
     """Forward step function with all the communications.
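Note: swap_key_value_dict reorders the cached key/value memory along the batch dimension (dim 1) so the inference cache follows the surviving beams. A standalone sketch of the same indexing on a plain dict of tensors (illustrative shapes, no Megatron classes):

    import torch

    seq_len, batch, kv_dim = 5, 4, 8
    key_value_memory_dict = {
        1: (torch.randn(seq_len, batch, kv_dim),    # keys for layer 1
            torch.randn(seq_len, batch, kv_dim)),   # values for layer 1
    }
    k_old = key_value_memory_dict[1][0].clone()

    batch_idx = torch.tensor([2, 0, 1, 3])          # e.g. best beam ids for the next step
    for layer_number, (k, v) in key_value_memory_dict.items():
        assert len(batch_idx) == k.shape[1]         # batch size must stay the same
        key_value_memory_dict[layer_number] = (k[:, batch_idx], v[:, batch_idx])

    # Column 0 of the swapped cache is what used to be column 2.
    k_new, _ = key_value_memory_dict[1]
    assert torch.equal(k_new[:, 0], k_old[:, 2])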
megatron/text_generation/generation.py (+127 −1)

@@ -26,6 +26,7 @@ from .communication import (
     broadcast_from_last_to_first_pipeline_stage)
 from .forward_step import ForwardStep
 from .sampling import sample
+from .beam_utils import BeamHypotheses

 def score_and_return_on_first_stage(model, tokens, lengths):
     """Function for just scoring.

@@ -262,7 +263,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                                tensor=done)
             if use_eod_token_for_early_termination and done:
                 break

     # ===================================================
     # Update the length of based on max generated length.
     # ===================================================

@@ -285,6 +286,131 @@ def generate_tokens_probs_and_return_on_first_stage(
     return tokens, generated_sequence_lengths, output_log_probs

+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size,
+                                          stop_token, num_return_gen, length_penalty):
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    batch_size = tokens.size(0)
+    assert(batch_size == 1)
+    prompt_length = lengths.item()
+    final_sequence_length = tokens.size(1)
+    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
+
+    # If the context is too big, this happens
+    if prompt_length >= final_sequence_length:
+        raise ValueError("context length + tokens_to_generate too large")
+
+    # forward step.
+    forward_step = ForwardStep(model, beam_size, final_sequence_length)
+
+    beam_hyp = BeamHypotheses(beam_size, length_penalty)
+    best_batches = None
+    done = torch.zeros(1, dtype=torch.uint8, device=torch.cuda.current_device())
+    scores = torch.zeros(beam_size,
+                         dtype=torch.float32,
+                         device=torch.cuda.current_device()).unsqueeze(1)
+    scores_size_tensor, tokens_size_tensor = None, None
+    # =============
+    # Run infernece
+    # =============
+    with torch.no_grad():
+        tokens = tokens.repeat(beam_size, 1)
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
+        prev_context_length = 0
+        for context_length in range(prompt_length, final_sequence_length):
+
+            # Pick the slice that we need to pass through the network.
+            tokens2use = tokens[:, prev_context_length:context_length]
+            positions2use = position_ids[:, prev_context_length:context_length]
+            attention_mask2use = attention_mask[
+                ..., prev_context_length:context_length, :context_length]
+
+            # logits will be meanigful only in the last pipeline stage.
+            logits = forward_step(tokens2use, positions2use, attention_mask2use)
+
+            if mpu.is_pipeline_last_stage():
+                vocab_size = logits.size(2)
+                log_probs = F.log_softmax(logits, dim=2)
+                new_scores = log_probs[:, -1, :] + scores
+
+                if context_length == prompt_length:  # if this is the first one
+                    sorted_scores, indices = torch.sort(new_scores[0,:], descending=True)
+                else:
+                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
+
+                best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
+                best_words = indices[:2 * beam_size] % vocab_size
+                best_scores = sorted_scores[: 2 * beam_size]
+
+                next_beams = []
+                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
+                    zip(best_words, best_scores, best_beam_ids)
+                ):
+                    if token_id.item() == stop_token:
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
+                        if is_beam_token_worse_than_top_num_beams:
+                            continue
+                        beam_hyp.add(
+                            tokens[beam_id].clone(),
+                            beam_score,
+                            context_length + 1 - prompt_length
+                        )
+                    else:
+                        # add next predicted token since it is not eos_token
+                        next_beams.append((token_id, beam_score, beam_id))
+
+                    if len(next_beams) == beam_size:
+                        break
+
+                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+                    done = torch.ones(1, dtype=torch.uint8, device=torch.cuda.current_device())
+
+                best_batches = tokens.new([item[2] for item in next_beams])
+                tokens = tokens[best_batches,:]
+                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
+                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)

+            # torch.distributed.barrier()
+            done = broadcast_from_last_pipeline_stage(1, torch.uint8, done)
+            if done:
+                break
+
+            # Update the tokens on the first stage so the next input to
+            # the network is correct.
+            copy_from_last_to_first_pipeline_stage(tokens.size(), torch.int64, tokens)
+
+            # set inference key values to make it consistent with best beam index
+            best_batches = broadcast_from_last_pipeline_stage(beam_size, torch.int64, best_batches)
+            forward_step.inference_params.swap_key_value_dict(best_batches)
+
+            # Update the context length for the next token generation.
+            prev_context_length = context_length
+
+        if mpu.is_pipeline_last_stage():
+            # if cannot find stop token, add open beams to hyps
+            if not done:
+                for beam_id in range(beam_size):
+                    beam_hyp.add(tokens[beam_id].clone(), scores[beam_id],
+                                 context_length + 1 - prompt_length)
+
+            # rank based on scores
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+            num_return_gen = min(num_return_gen, len(sorted_hyps))
+            scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+            tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+            scores = torch.stack(scores, dim=0)
+            tokens = torch.stack(tokens, dim=0)
+            scores_size_tensor = torch.tensor(scores.shape, dtype=torch.int64,
+                                              device=torch.cuda.current_device())
+            tokens_size_tensor = torch.tensor(tokens.shape, dtype=torch.int64,
+                                              device=torch.cuda.current_device())
+
+        scores_size_tensor = broadcast_from_last_pipeline_stage(1, torch.int64, scores_size_tensor)
+        tokens_size_tensor = broadcast_from_last_pipeline_stage(2, torch.int64, tokens_size_tensor)
+
+        scores = broadcast_from_last_to_first_pipeline_stage(tuple(scores_size_tensor),
+                                                             torch.float32, scores)
+        tokens = broadcast_from_last_to_first_pipeline_stage(tuple(tokens_size_tensor),
+                                                             torch.int64, tokens)
+
+    return tokens, scores

 def _build_attention_mask_and_position_ids(tokens):
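Note: inside the beam search loop, the per-beam, per-token scores are flattened, sorted once, and the originating beam index and token index are recovered with an integer divide and a modulo. A small sketch of that recovery (plain PyTorch; the numbers are illustrative):

    import torch

    beam_size, vocab_size = 2, 5
    new_scores = torch.tensor([[0.1, 0.9, 0.2, 0.0, 0.3],
                               [0.8, 0.4, 0.7, 0.6, 0.5]])    # [beam_size, vocab_size]

    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

    best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
    best_words = indices[:2 * beam_size] % vocab_size

    # Top score 0.9 sits at beam 0, token 1; the next one, 0.8, at beam 1, token 0.
    assert best_beam_ids.tolist() == [0, 1, 1, 1]
    assert best_words.tolist() == [1, 0, 2, 3]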