Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
0760822b
"docs/vscode:/vscode.git/clone" did not exist on "fad855cdb4910ddc44730de0a2bb39a3e9f6c4a6"
Commit
0760822b
authored
Apr 05, 2023
by
Mostofa Patwary
Browse files
addressing the comments
parent
bf1da27e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
20 deletions
+35
-20
megatron/model/fused_layer_norm.py
megatron/model/fused_layer_norm.py
+18
-11
megatron/model/transformer.py
megatron/model/transformer.py
+17
-9
No files found.
megatron/model/fused_layer_norm.py
View file @
0760822b
...
...
@@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter
from
torch.nn
import
init
import
importlib
from
megatron
import
get_args
from
megatron.core.utils
import
make_viewless_tensor
try
:
...
...
@@ -59,9 +58,17 @@ class MixedFusedLayerNorm(torch.nn.Module):
def
__init__
(
self
,
normalized_shape
,
eps
=
1e-5
,
no_persist_layer_norm
=
True
,
sequence_parallel
=
False
):
sequence_parallel
=
False
,
apply_layernorm_1p
=
False
):
super
(
MixedFusedLayerNorm
,
self
).
__init__
()
self
.
apply_layernorm_1p
=
False
if
apply_layernorm_1p
:
self
.
weight_adjustment
=
1
self
.
apply_layernorm_1p
=
True
else
:
self
.
weight_adjustment
=
0
global
fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda
=
importlib
.
import_module
(
"fused_mix_prec_layer_norm_cuda"
)
...
...
@@ -89,23 +96,23 @@ class MixedFusedLayerNorm(torch.nn.Module):
# set sequence parallelism flag on weight and bias parameters
setattr
(
self
.
weight
,
'sequence_parallel'
,
self
.
sequence_parallel
)
setattr
(
self
.
bias
,
'sequence_parallel'
,
self
.
sequence_parallel
)
args
=
get_args
()
self
.
weight_adjustment
=
0
if
args
.
apply_layernorm_1p
:
self
.
weight_adjustment
=
1
def
reset_parameters
(
self
):
init
.
ones_
(
self
.
weight
)
init
.
zeros_
(
self
.
bias
)
if
self
.
apply_layernorm_1p
:
init
.
zeros_
(
self
.
weight
)
init
.
zeros_
(
self
.
bias
)
else
:
init
.
ones_
(
self
.
weight
)
init
.
zeros_
(
self
.
bias
)
def
forward
(
self
,
input
):
if
self
.
no_persist_layer_norm
:
return
FusedLayerNormAffineFunction
.
apply
(
input
,
self
.
weight
+
self
.
weight_adjustment
,
self
.
bias
,
self
.
normalized_shape
,
self
.
eps
)
input
,
self
.
weight
+
self
.
weight_adjustment
,
\
self
.
bias
,
self
.
normalized_shape
,
self
.
eps
)
else
:
output
=
FastLayerNormFN
.
apply
(
input
,
self
.
weight
+
self
.
weight_adjustment
,
self
.
bias
,
self
.
eps
)
...
...
megatron/model/transformer.py
View file @
0760822b
...
...
@@ -10,11 +10,11 @@ from megatron import get_timers, get_args, core, get_num_microbatches
from
.module
import
MegatronModule
from
megatron.core
import
mpu
,
tensor_parallel
from
megatron.core.enums
import
ModelType
from
megatron.model
import
LayerNorm
from
megatron.model.enums
import
AttnMaskType
,
LayerType
,
AttnType
from
megatron.model.fused_softmax
import
FusedScaleMaskSoftmax
from
megatron.model.fused_bias_gelu
import
bias_gelu_impl
from
megatron.model.utils
import
attention_mask_func
,
openai_gelu
,
erf_gelu
from
megatron.model
import
LayerNorm
try
:
from
einops
import
rearrange
...
...
@@ -635,8 +635,10 @@ class ParallelTransformerLayer(MegatronModule):
self
.
bf16
=
args
.
bf16
self
.
fp32_residual_connection
=
args
.
fp32_residual_connection
#if args.apply_layernorm_1p:
# from megatron.model import LayerNorm1P as LayerNorm
apply_layernorm_1p
=
False
if
args
.
apply_layernorm_1p
:
apply_layernorm_1p
=
True
#from megatron.model import LayerNorm1P as LayerNorm
#else:
# from megatron.model import LayerNorm
...
...
@@ -645,7 +647,8 @@ class ParallelTransformerLayer(MegatronModule):
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
,
sequence_parallel
=
args
.
sequence_parallel
)
sequence_parallel
=
args
.
sequence_parallel
,
apply_layernorm_1p
=
apply_layernorm_1p
)
# Self attention.
self
.
self_attention
=
ParallelAttention
(
...
...
@@ -663,7 +666,8 @@ class ParallelTransformerLayer(MegatronModule):
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
,
sequence_parallel
=
args
.
sequence_parallel
)
sequence_parallel
=
args
.
sequence_parallel
,
apply_layernorm_1p
=
apply_layernorm_1p
)
if
self
.
layer_type
==
LayerType
.
decoder
:
self
.
inter_attention
=
ParallelAttention
(
...
...
@@ -676,7 +680,8 @@ class ParallelTransformerLayer(MegatronModule):
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
,
sequence_parallel
=
args
.
sequence_parallel
)
sequence_parallel
=
args
.
sequence_parallel
,
apply_layernorm_1p
=
apply_layernorm_1p
)
# MLP
if
args
.
num_experts
is
not
None
:
...
...
@@ -1025,8 +1030,10 @@ class ParallelTransformer(MegatronModule):
self
.
layers
=
torch
.
nn
.
ModuleList
(
[
build_layer
(
i
+
1
+
offset
)
for
i
in
range
(
self
.
num_layers
)])
#if args.apply_layernorm_1p:
# from megatron.model import LayerNorm1P as LayerNorm
apply_layernorm_1p
=
False
if
args
.
apply_layernorm_1p
:
apply_layernorm_1p
=
True
#from megatron.model import LayerNorm1P as LayerNorm
#else:
# from megatron.model import LayerNorm
...
...
@@ -1036,7 +1043,8 @@ class ParallelTransformer(MegatronModule):
args
.
hidden_size
,
eps
=
args
.
layernorm_epsilon
,
no_persist_layer_norm
=
args
.
no_persist_layer_norm
,
sequence_parallel
=
args
.
sequence_parallel
)
sequence_parallel
=
args
.
sequence_parallel
,
apply_layernorm_1p
=
apply_layernorm_1p
)
def
_get_layer
(
self
,
layer_number
):
return
self
.
layers
[
layer_number
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment