OpenDAS / Megatron-LM · Commits · 4a3213f1

Commit 4a3213f1, authored Nov 17, 2021 by Jared Casper

Merge branch 'slym/persist_ln' into 'main'

Persistent layer norm

See merge request ADLR/megatron-lm!351

Parents: 3ae12a47, a2fdcdf0
Showing 3 changed files with 42 additions and 7 deletions:
megatron/arguments.py                 +15  -0
megatron/model/fused_layer_norm.py    +19  -3
megatron/model/transformer.py          +8  -4
megatron/arguments.py
@@ -257,6 +257,16 @@ def parse_args(extra_args_provider=None, defaults={},
             'currently distrobuted checkpoint activations only supported for ' \
             'nointerleaved pipeline parallelism'
 
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    # Persistent fused layer norm.
+    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
+        args.no_persist_layer_norm = True
+        if args.rank == 0:
+            print('Persistent fused layer norm kernel is supported from '
+                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
+                  'Defaulting to no_persist_layer_norm=True')
+
     _print_args(args)
     return args
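Note: the gate above just parses torch.__version__ and forces the non-persistent fallback on anything older than PyTorch 1.11. A minimal standalone sketch of the same check; the args namespace here is a stand-in, not Megatron's real parser output:

# Standalone sketch of the version gate; `args` is a stand-in namespace.
from types import SimpleNamespace
import torch

args = SimpleNamespace(no_persist_layer_norm=False, rank=0)

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
    # Persistent kernel needs PyTorch >= 1.11; force the fallback.
    args.no_persist_layer_norm = True
    if args.rank == 0:
        print('Falling back to no_persist_layer_norm=True')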
@@ -486,6 +496,11 @@ def _add_training_args(parser):
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
                        'gradient compuation of a column-linear layer.')
+    group.add_argument('--no-persist-layer-norm', action='store_true',
+                       help='Disable using persistent fused layer norm kernel. '
+                       'This kernel supports only a set of hidden sizes. Please '
+                       'check persist_ln_hidden_sizes if your hidden '
+                       'size is supported.')
 
     return parser
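Because the new option is a store_true flag, the persistent kernel remains the default and --no-persist-layer-norm is purely an opt-out. A toy parser sketch of the flag's semantics (illustrative only, not Megatron's full _add_training_args):

# Toy parser illustrating the new flag's behavior.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-persist-layer-norm', action='store_true',
                    help='Disable the persistent fused layer norm kernel.')

print(parser.parse_args([]).no_persist_layer_norm)                           # False -> persistent kernel allowed
print(parser.parse_args(['--no-persist-layer-norm']).no_persist_layer_norm)  # True  -> always use the non-persistent path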
megatron/model/fused_layer_norm.py
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
 
+from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+
 
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -61,13 +63,22 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):
 
-  def __init__(self, normalized_shape, eps=1e-5):
+  def __init__(self, normalized_shape, eps=1e-5,
+               no_persist_layer_norm=True):
     super(MixedFusedLayerNorm, self).__init__()
 
     global fused_mix_prec_layer_norm_cuda
     fused_mix_prec_layer_norm_cuda = importlib.import_module(
       "fused_mix_prec_layer_norm_cuda")
 
+    # List of hiddens sizes supported in the persistent layer norm kernel
+    # If the hidden size is not supported, fall back to the non-persistent
+    # kernel.
+    persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+        5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+        24576, 25600, 30720, 32768, 40960, 49152, 65536]
+    if normalized_shape not in persist_ln_hidden_sizes:
+        no_persist_layer_norm = True
+
     if isinstance(normalized_shape, numbers.Integral):
       normalized_shape = (normalized_shape,)
     self.normalized_shape = torch.Size(normalized_shape)
@@ -75,6 +86,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
     self.weight = Parameter(torch.Tensor(*normalized_shape))
     self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()
+    self.no_persist_layer_norm = no_persist_layer_norm
 
 
   def reset_parameters(self):
@@ -85,6 +97,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
 
   def forward(self, input):
 
-    return FusedLayerNormAffineFunction.apply(
-      input, self.weight, self.bias, self.normalized_shape, self.eps)
+    if self.no_persist_layer_norm:
+      return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
+    else:
+      return FastLayerNormFN.apply(
+        input, self.weight, self.bias, self.eps)
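Taken together, MixedFusedLayerNorm now decides per instance whether it may use apex's persistent FastLayerNormFN: the hidden size must be on the whitelist and the caller must not have disabled persistence. A rough sketch of that selection logic which runs without apex; ToyFusedLayerNorm and the use of torch.nn.functional.layer_norm for both branches are illustrative stand-ins only:

# Illustrative only: mimics the persistent/non-persistent selection without apex.
import torch
import torch.nn.functional as F

PERSIST_LN_HIDDEN_SIZES = {1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120,
                           6144, 8192, 10240, 12288, 12800, 15360, 16384,
                           18432, 20480, 24576, 25600, 30720, 32768, 40960,
                           49152, 65536}

class ToyFusedLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, no_persist_layer_norm=True):
        super().__init__()
        # Unsupported hidden sizes silently fall back to the generic kernel.
        if hidden_size not in PERSIST_LN_HIDDEN_SIZES:
            no_persist_layer_norm = True
        self.no_persist_layer_norm = no_persist_layer_norm
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        if self.no_persist_layer_norm:
            # Generic path (FusedLayerNormAffineFunction in Megatron).
            return F.layer_norm(x, (self.hidden_size,), self.weight, self.bias, self.eps)
        # Persistent path (FastLayerNormFN in Megatron); same math in this sketch.
        return F.layer_norm(x, (self.hidden_size,), self.weight, self.bias, self.eps)

ln = ToyFusedLayerNorm(4096, no_persist_layer_norm=False)
print(ln.no_persist_layer_norm)   # False: 4096 is in the supported list
ln2 = ToyFusedLayerNorm(4000, no_persist_layer_norm=False)
print(ln2.no_persist_layer_norm)  # True: 4000 forces the fallback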
megatron/model/transformer.py
@@ -423,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
 
         # Self attention.
         self.self_attention = ParallelAttention(
@@ -438,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)
 
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -449,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layernorm on the attention output.
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)
 
         # MLP
         self.mlp = ParallelMLP(init_method,
@@ -602,7 +605,8 @@ class ParallelTransformer(MegatronModule):
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)
 
     def _get_layer(self, layer_number):
         return self.layers[layer_number]
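The transformer.py changes are pure plumbing: every LayerNorm construction now forwards args.no_persist_layer_norm, so the choice made in arguments.py reaches the kernel selection in fused_layer_norm.py. A compact sketch of that pass-through pattern; the LayerNorm subclass and args namespace below are hypothetical stand-ins, not Megatron's fused implementation:

# Sketch of the kwarg plumbing from args into each layer norm instance.
from types import SimpleNamespace
import torch

class LayerNorm(torch.nn.LayerNorm):
    def __init__(self, hidden_size, eps=1e-5, no_persist_layer_norm=True):
        super().__init__(hidden_size, eps=eps)
        self.no_persist_layer_norm = no_persist_layer_norm

args = SimpleNamespace(hidden_size=4096, layernorm_epsilon=1e-5,
                       no_persist_layer_norm=False)

# Same construction pattern as input_layernorm, post_attention_layernorm,
# post_inter_attention_layernorm and final_layernorm in the diff above.
input_layernorm = LayerNorm(
    args.hidden_size,
    eps=args.layernorm_epsilon,
    no_persist_layer_norm=args.no_persist_layer_norm)
print(input_layernorm.no_persist_layer_norm)  # False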