OpenDAS / Megatron-LM

Commit c0f10643 authored Mar 02, 2022 by Vijay Korthikanti

layernorm grad sync + name chnages

parent 5d4689c4
Showing 9 changed files with 79 additions and 61 deletions.
megatron/model/fused_layer_norm.py       +8   -1
megatron/model/language_model.py         +7   -2
megatron/model/transformer.py            +13  -6
megatron/model/vision/vit_backbone.py    +0   -1
megatron/mpu/__init__.py                 +5   -6
megatron/mpu/layers.py                   +8   -8
megatron/mpu/mappings.py                 +17  -36
megatron/optimizer/__init__.py           +0   -1
megatron/training.py                     +21  -0
megatron/model/fused_layer_norm.py  (+8, -1)

@@ -67,7 +67,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
+    def __init__(self, normalized_shape, eps=1e-5,
+                 no_persist_layer_norm=True,
+                 sequence_parallel=False):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda

@@ -92,6 +94,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
+        self.sequence_parallel = sequence_parallel
+
+        # set sequence parallelism flag on weight and bias parameters
+        self.weight.sequence_parallel = self.sequence_parallel
+        self.bias.sequence_parallel = self.sequence_parallel

     def reset_parameters(self):
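The new sequence_parallel argument only tags the layernorm weight and bias; the math is unchanged. That tag is what the train_step change in megatron/training.py (below) later uses to find the parameters whose gradients must be summed across tensor-model-parallel ranks. A minimal, self-contained sketch of the tagging pattern (plain PyTorch, hypothetical helper names, not the Megatron classes):

import torch

def tag_sequence_parallel(module):
    # Mark every parameter of a replicated layer (e.g. a LayerNorm) so a
    # later pass can find it; the attribute itself changes no computation.
    for param in module.parameters():
        param.sequence_parallel = True

def sequence_parallel_params(module):
    # getattr with a default keeps untagged parameters out of the list.
    return [p for p in module.parameters()
            if getattr(p, 'sequence_parallel', False)]

layernorm = torch.nn.LayerNorm(1024)
tag_sequence_parallel(layernorm)
assert len(sequence_parallel_params(layernorm)) == 2   # weight and bias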
megatron/model/language_model.py  (+7, -2)

@@ -29,8 +29,13 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
+
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    if not args.model_parallel_memory_opt:
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+    else:
+        input_parallel = input_
     # Matrix multiply.
     if bias is None:
         logits_parallel = F.linear(input_parallel, word_embeddings_weight)

@@ -40,7 +45,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return mpu.gather_along_last_dim_from_tensor_model_parallel_region(logits_parallel)
+    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,
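mpu.copy_to_tensor_model_parallel_region is an identity in the forward pass and an all-reduce of the gradient in the backward pass; the hunk skips it when model_parallel_memory_opt is set, presumably because on that path the gather that produced the logits input already performs the needed gradient reduction in its backward pass (see megatron/mpu/mappings.py below). A rough sketch of the op being skipped (definitions only; assumes torch.distributed has been initialized):

import torch
import torch.distributed as dist

class CopyToModelParallelRegionSketch(torch.autograd.Function):
    # Identity forward; all-reduce of the gradient backward. This mirrors
    # the behaviour of mpu.copy_to_tensor_model_parallel_region without the
    # Megatron process-group plumbing.
    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output)
        return grad_output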
megatron/model/transformer.py  (+13, -6)

@@ -447,7 +447,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.model_parallel_memory_opt)

         # Self attention.
         self.self_attention = ParallelAttention(

@@ -464,7 +465,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.model_parallel_memory_opt)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(

@@ -476,7 +478,8 @@ class ParallelTransformerLayer(MegatronModule):
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.model_parallel_memory_opt)

         # MLP
         self.mlp = ParallelMLP(init_method,

@@ -697,7 +700,8 @@ class ParallelTransformer(MegatronModule):
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.model_parallel_memory_opt)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -775,7 +779,7 @@ class ParallelTransformer(MegatronModule):
             hidden_states = hidden_states.transpose(0, 1).contiguous()

             if self.model_parallel_memory_opt:
-                hidden_states = mpu.scatter_along_first_dim_to_tensor_model_parallel_region(hidden_states)
+                hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)
         else:
             # See set_input_tensor()

@@ -806,6 +810,9 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

+            if self.model_parallel_memory_opt:
+                encoder_output = mpu.scatter_to_sequence_parallel_region(encoder_output)
+
         # Forward pass.
         if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,

@@ -829,7 +836,7 @@ class ParallelTransformer(MegatronModule):
             hidden_states = self.final_layernorm(hidden_states)

             if self.model_parallel_memory_opt:
-                hidden_states = mpu.gather_along_first_dim_from_tensor_model_parallel_region(hidden_states)
+                hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)

             output = hidden_states.transpose(0, 1).contiguous()
         else:
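With the memory optimization enabled, the transformer scatters the [sequence, batch, hidden] activations along the sequence dimension on entry and gathers them back after the final layernorm, so the layernorm and dropout activations each tensor-parallel rank stores cover only its slice of the sequence. A single-process stand-in for the two ops, using torch.chunk/torch.cat in place of the distributed scatter/all-gather:

import torch

def scatter_to_sequence_parallel(x, rank, world_size):
    # keep this rank's slice of the sequence (first) dimension
    return torch.chunk(x, world_size, dim=0)[rank].contiguous()

def gather_from_sequence_parallel(shards):
    # reassemble the full sequence; the real op all-gathers over the
    # tensor-model-parallel group and reduce-scatters in its backward pass
    return torch.cat(shards, dim=0)

x = torch.randn(8, 2, 16)                      # [sequence, batch, hidden]
shards = [scatter_to_sequence_parallel(x, r, 4) for r in range(4)]
assert shards[0].shape == (2, 2, 16)
assert torch.equal(gather_from_sequence_parallel(shards), x)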
megatron/model/vision/vit_backbone.py  (+0, -1)

@@ -21,7 +21,6 @@ import torch
 import apex
 import torch.nn.functional as F
 from megatron import get_args
-from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,
megatron/mpu/__init__.py  (+5, -6)

@@ -58,12 +58,11 @@ from .layers import (set_tensor_model_parallel_attributes,
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
-from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
-from .mappings import scatter_along_first_dim_to_tensor_model_parallel_region
-from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_last_dim_to_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import scatter_to_sequence_parallel_region
+from .mappings import gather_from_seqeuence_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
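The renamed exports encode the partitioning scheme rather than the dimension index: the *_tensor_model_parallel_region ops split and gather the hidden (last) dimension, while the new *_sequence_parallel_region ops split and gather the sequence (first) dimension, exactly what the old along_last_dim/along_first_dim names spelled out. A shape-only, single-process illustration for rank 0 with a tensor-parallel size of 4:

import torch

x = torch.randn(8, 2, 16)      # [sequence, batch, hidden]
tp = 4

# scatter_to_tensor_model_parallel_region: partition the hidden (last) dim
hidden_shard = torch.chunk(x, tp, dim=-1)[0]
assert hidden_shard.shape == (8, 2, 4)

# scatter_to_sequence_parallel_region: partition the sequence (first) dim
sequence_shard = torch.chunk(x, tp, dim=0)[0]
assert sequence_shard.shape == (2, 2, 16)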
megatron/mpu/layers.py  (+8, -8)

@@ -29,11 +29,11 @@ from .initialize import get_tensor_model_parallel_rank
 from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_along_first_dim_from_tensor_model_parallel_region
-from .mappings import gather_along_last_dim_from_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
-from .mappings import scatter_along_last_dim_to_tensor_model_parallel_region
-from .mappings import reduce_scatter_along_first_dim_to_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide

@@ -328,7 +328,7 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             # Set up backprop all-reduce.
             if self.model_parallel_memory_opt:
-                input_parallel = gather_along_first_dim_from_tensor_model_parallel_region(input_)
+                input_parallel = gather_from_sequence_parallel_region(input_)
             else:
                 input_parallel = copy_to_tensor_model_parallel_region(input_)

@@ -338,7 +338,7 @@ class ColumnParallelLinear(torch.nn.Module):
         if self.gather_output:
             # All-gather across the partitions.
             assert not self.model_parallel_memory_opt
-            output = gather_along_last_dim_from_tensor_model_parallel_region(output_parallel)
+            output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None

@@ -433,12 +433,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             assert not self.model_parallel_memory_opt
-            input_parallel = scatter_along_last_dim_to_tensor_model_parallel_region(input_)
+            input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
-            output_ = reduce_scatter_along_first_dim_to_tensor_model_parallel_region(output_parallel)
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
         else:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
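On the sequence-parallel path, ColumnParallelLinear all-gathers its input along the sequence dimension before the matmul (instead of the identity copy), and RowParallelLinear reduce-scatters its output back to a sequence shard (instead of all-reducing it). A single-process shape walk-through of the column-parallel case, with the all-gather emulated by concatenating hypothetical per-rank shards:

import torch
import torch.nn.functional as F

tp, s, b, h, ffn = 2, 8, 2, 16, 64
w_shard = torch.randn(ffn // tp, h)            # this rank's column shard

# sequence-parallel input: each rank holds s/tp of the sequence
x_shards = [torch.randn(s // tp, b, h) for _ in range(tp)]

# gather_from_sequence_parallel_region: every rank ends up with the full
# sequence (the backward of this op reduce-scatters the gradient)
x_full = torch.cat(x_shards, dim=0)            # [s, b, h]

y_local = F.linear(x_full, w_shard)            # [s, b, ffn/tp]
assert y_local.shape == (s, b, ffn // tp)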
megatron/mpu/mappings.py  (+17, -36)

@@ -32,7 +32,6 @@ def _reduce(input_):
     return input_

-
 def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""

@@ -51,6 +50,7 @@ def _split_along_last_dim(input_):
     return output

+
 def _split_along_first_dim(input_):
     """Split the tensor along its first dimension and keep the
     corresponding slice."""

@@ -174,7 +174,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
         return grad_output


-class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
+class _ScatterToModelParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

     @staticmethod

@@ -190,7 +190,7 @@ class _ScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
         return _gather_along_last_dim(grad_output)


-class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
+class _GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""

     @staticmethod

@@ -203,10 +203,10 @@ class _GatherAlongLastDimFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        return _reduce_scatter_along_last_dim(grad_output)
+        return _split_along_last_dim(grad_output)


-class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chuck to the rank."""

     @staticmethod

@@ -222,7 +222,7 @@ class _ScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
         return _gather_along_first_dim(grad_output)


-class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
     """Gather the input from model parallel region and concatinate."""  #TODO

     @staticmethod

@@ -238,23 +238,7 @@ class _GatherAlongFirstDimFromModelParallelRegion(torch.autograd.Function):
         return _reduce_scatter_along_first_dim(grad_output)


-class _ReduceScatterAlongLastDimToModelParallelRegion(torch.autograd.Function):
-    """Reduce scatter the input from the model parallel region."""
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return _reduce_scatter_along_last_dim(input_)
-
-    @staticmethod
-    def forward(ctx, input_):
-        return _reduce_scatter_along_last_dim(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _gather_along_last_dim(grad_output)
-
-
-class _ReduceScatterAlongFirstDimToModelParallelRegion(torch.autograd.Function):
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
     """Reduce scatter the input from the model parallel region."""

     @staticmethod

@@ -282,25 +266,22 @@ def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_along_last_dim_to_tensor_model_parallel_region(input_):
-    return _ScatterAlongLastDimToModelParallelRegion.apply(input_)
+def scatter_to_tensor_model_parallel_region(input_):
+    return _ScatterToModelParallelRegion.apply(input_)


-def gather_along_last_dim_from_tensor_model_parallel_region(input_):
-    return _GatherAlongLastDimFromModelParallelRegion.apply(input_)
+def gather_from_tensor_model_parallel_region(input_):
+    return _GatherFromModelParallelRegion.apply(input_)


-def scatter_along_first_dim_to_tensor_model_parallel_region(input_):
-    return _ScatterAlongFirstDimToModelParallelRegion.apply(input_)
+def scatter_to_sequence_parallel_region(input_):
+    return _ScatterToSequenceParallelRegion.apply(input_)


-def gather_along_first_dim_from_tensor_model_parallel_region(input_):
-    return _GatherAlongFirstDimFromModelParallelRegion.apply(input_)
+def gather_from_seqeuence_parallel_region(input_):
+    return _GatherFromSequenceParallelRegion.apply(input_)


-def reduce_scatter_along_first_dim_to_tensor_model_parallel_region(input_):
-    return _ReduceScatterAlongFirstDimToModelParallelRegion.apply(input_)
+def reduce_scatter_to_sequence_parallel_region(input_):
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)

-
-def reduce_scatter_along_last_dim_to_tensor_model_parallel_region(input_):
-    return _ReduceScatterAlongLastDimToModelParallelRegion.apply(input_)
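The renamed sequence-parallel autograd functions come in forward/backward duals: the scatter op (split along the first dim) gathers in its backward pass, and the gather op reduce-scatters in its backward pass, as the context lines above show; this is what keeps gradients consistent across the tensor-model-parallel group. A skeleton of the gather op written directly against torch.distributed collectives (definitions only; assumes an initialized process group passed in as group, not the actual Megatron implementation):

import torch
import torch.distributed as dist

class GatherFromSequenceParallelRegionSketch(torch.autograd.Function):
    """Forward: all-gather the local sequence shard along dim 0.
    Backward: reduce-scatter the gradient back to a shard."""

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous(), group=group)
        return torch.cat(shards, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        grad_shards = [g.contiguous()
                       for g in grad_output.chunk(world_size, dim=0)]
        grad_input = torch.empty_like(grad_shards[0])
        dist.reduce_scatter(grad_input, grad_shards, group=ctx.group)
        return grad_input, None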
megatron/optimizer/__init__.py  (+0, -1)

@@ -17,7 +17,6 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
megatron/training.py  (+21, -0)

@@ -23,6 +23,7 @@ import time
 _TRAIN_START_TIME = time.time()
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from megatron import get_args
 from megatron import get_signal_handler

@@ -418,6 +419,26 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

+    # All-reduce layernorm parameters across model parallel nodes
+    # when sequence parallelism is used
+    if args.get_tensor_model_parallel_world_size > 1 and \
+            args.model_parallel_memory_opt:
+        grads = []
+        for model_module in model:
+            unwrapped_model = unwrap_model(
+                model_module, (torchDDP, LocalDDP, Float16Module))
+            for param in unwrapped_model.parameters():
+                if param.get_attr('sequence_parallel', False):
+                    assert param.requires_grad and param.grad is not None
+                    grads.append(param.grad.data)
+        coalesced = _flatten_dense_tensors(grads)
+        coalesced /= mpu.get_tensor_model_parallel_world_size()
+        torch.distributed.all_reduce(
+            coalesced, group=mpu.get_tensor_model_parallel_group())
+        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
+
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
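This is the "layernorm grad sync" part of the commit: after the backward pass, the gradients of every parameter tagged sequence_parallel are combined across the tensor-model-parallel group, because each rank only processed its own slice of the sequence. A condensed sketch of the same synchronization, written with the stock getattr and mpu calls (the hunk above additionally scales the flattened buffer by the tensor-parallel world size before reducing; the sketch keeps the plain sum):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_layernorm_grads(model_chunks, mpu):
    """Sum sequence-parallel (layernorm) gradients across the
    tensor-model-parallel group; model_chunks are unwrapped model modules."""
    grads = [param.grad.data
             for model_chunk in model_chunks
             for param in model_chunk.parameters()
             if getattr(param, 'sequence_parallel', False)
             and param.grad is not None]
    if not grads:
        return
    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(
        coalesced, group=mpu.get_tensor_model_parallel_group())
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)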