OpenDAS / Megatron-LM / Commits

Commit 4e891fe9, authored Apr 05, 2023 by Jared Casper
Merge branch 'next-best-lm/merge-layernorm-1p-main' into 'main'

layernorm1p added

See merge request ADLR/megatron-lm!557
Parents: 7bd25e26, 33a58153
Showing 4 changed files with 27 additions and 14 deletions (+27 -14)
megatron/arguments.py               +3  -0
megatron/model/fused_layer_norm.py  +14 -8
megatron/model/t5_model.py          +1  -1
megatron/model/transformer.py       +9  -5
megatron/arguments.py
...
@@ -534,6 +534,9 @@ def _add_network_size_args(parser):
                        'This is added for computational efficieny reasons.')
     group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                        help='Layer norm epsilon.')
+    group.add_argument('--apply-layernorm-1p', action='store_true',
+                       help='Adjust LayerNorm weights such that they are centered '
+                       'around zero. This improves numerical stability.')
     group.add_argument('--apply-residual-connection-post-layernorm',
                        action='store_true',
                        help='If set, use original BERT residula connection '
...
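For context, the new option is a plain store_true flag. A minimal standalone argparse sketch (hypothetical wiring, mirroring only the lines added above rather than Megatron's actual `_add_network_size_args` group) would look like this:

```python
import argparse

# Standalone sketch of the new flag (illustrative only; in Megatron-LM it is
# registered on the network-size argument group in megatron/arguments.py).
parser = argparse.ArgumentParser()
parser.add_argument('--apply-layernorm-1p', action='store_true',
                    help='Adjust LayerNorm weights such that they are centered '
                         'around zero. This improves numerical stability.')

args = parser.parse_args(['--apply-layernorm-1p'])
assert args.apply_layernorm_1p is True
```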
megatron/model/fused_layer_norm.py
...
@@ -58,9 +58,12 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5,
                  no_persist_layer_norm=True,
-                 sequence_parallel=False):
+                 sequence_parallel=False,
+                 apply_layernorm_1p=False):
         super(MixedFusedLayerNorm, self).__init__()
+        self.apply_layernorm_1p = apply_layernorm_1p
         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")
...
@@ -92,18 +95,21 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def reset_parameters(self):
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+        if self.apply_layernorm_1p:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)

     def forward(self, input):
+        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
         if self.no_persist_layer_norm:
             return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape, self.eps)
+                input, weight, self.bias, self.normalized_shape, self.eps)
         else:
             output = FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+                input, weight, self.bias, self.eps)
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
...
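The parameterization itself is simple: when apply_layernorm_1p is set, the learnable scale is stored centered around zero (initialized to 0 instead of 1) and shifted by +1 at apply time. Below is a minimal, unfused PyTorch sketch of that idea using torch.nn.functional.layer_norm rather than the Apex fused kernels that MixedFusedLayerNorm dispatches to; the class name LayerNorm1P and its structure are illustrative, not part of the commit:

```python
import torch
import torch.nn.functional as F
from torch.nn import init


class LayerNorm1P(torch.nn.Module):
    """Unfused sketch of the layernorm-1p parameterization (illustrative only).

    The stored weight is kept centered around zero; the effective scale used
    in the normalization is (weight + 1), matching the diff above.
    """

    def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p
        self.weight = torch.nn.Parameter(torch.empty(*self.normalized_shape))
        self.bias = torch.nn.Parameter(torch.empty(*self.normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        if self.apply_layernorm_1p:
            # Stored weight starts at 0, so the effective scale starts at 1.
            init.zeros_(self.weight)
        else:
            init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x):
        # Shift the stored, zero-centered weight back to the usual scale.
        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
        return F.layer_norm(x, self.normalized_shape, weight, self.bias, self.eps)
```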
megatron/model/t5_model.py
...
@@ -8,7 +8,7 @@ from megatron import get_args
 from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
-from megatron.model.transformer import LayerNorm
+from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
     get_linear_layer,
...
megatron/model/transformer.py
...
@@ -10,8 +10,8 @@ from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.core.enums import ModelType
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
...
@@ -712,7 +712,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         # Self attention.
         self.self_attention = ParallelAttention(
...
@@ -730,7 +731,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
...
@@ -743,7 +745,8 @@ class ParallelTransformerLayer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

         # MLP
         if args.num_experts is not None:
...
@@ -1108,7 +1111,8 @@ class ParallelTransformer(MegatronModule):
             args.hidden_size,
             eps=args.layernorm_epsilon,
             no_persist_layer_norm=args.no_persist_layer_norm,
-            sequence_parallel=args.sequence_parallel)
+            sequence_parallel=args.sequence_parallel,
+            apply_layernorm_1p=args.apply_layernorm_1p)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
...
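As a quick sanity check (reusing the hypothetical LayerNorm1P sketch above, not code from this commit): a freshly initialized 1p layer, whose stored weight is zero, should produce the same output as a standard torch.nn.LayerNorm whose weight is one, since only the stored parameterization differs.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 32)

ln_1p = LayerNorm1P(32, apply_layernorm_1p=True)  # stored weight = 0, effective scale = 1
ln_ref = torch.nn.LayerNorm(32)                   # standard weight = 1

# Same normalization at initialization; the 1p variant only changes how the
# scale parameter is stored and initialized.
assert torch.allclose(ln_1p(x), ln_ref(x), atol=1e-6)
```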