OpenDAS / Megatron-LM · Commit bf1da27e

Authored Apr 04, 2023 by Mostofa Patwary
Commit message: addressing comments
Parent: 5e079c87

Showing 3 changed files with 40 additions and 34 deletions:

  megatron/model/__init__.py          (+1, -1)
  megatron/model/fused_layer_norm.py  (+30, -25)
  megatron/model/transformer.py       (+9, -8)
megatron/model/__init__.py  (+1, -1)

 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-from .fused_layer_norm import MixedFusedLayerNorm1P as LayerNorm1P
+# from .fused_layer_norm import MixedFusedLayerNorm1P as LayerNorm1P
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
...
megatron/model/fused_layer_norm.py  (+30, -25)

...
@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron import get_args
 from megatron.core.utils import make_viewless_tensor
 try:
...
@@ -89,6 +90,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
         setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
         setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
+        args = get_args()
+        self.weight_adjustment = 0
+        if args.apply_layernorm_1p:
+            self.weight_adjustment = 1
     def reset_parameters(self):
...
@@ -100,10 +105,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
         if self.no_persist_layer_norm:
             return FusedLayerNormAffineFunction.apply(
-                input, self.weight, self.bias, self.normalized_shape, self.eps)
+                input, self.weight + self.weight_adjustment, self.bias, self.normalized_shape, self.eps)
         else:
             output = FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+                input, self.weight + self.weight_adjustment, self.bias, self.eps)
             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
             # a populated '_base' field). This will result in schedule.py's
...
@@ -117,26 +122,26 @@ class MixedFusedLayerNorm(torch.nn.Module):
-class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
-
-    def reset_parameters(self):
-        init.zeros_(self.weight)
-        init.zeros_(self.bias)
-
-    def forward(self, input):
-        if self.no_persist_layer_norm:
-            return FusedLayerNormAffineFunction.apply(
-                input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
-        else:
-            output = FastLayerNormFN.apply(
-                input, self.weight + 1, self.bias, self.eps)
-
-            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
-            # a populated '_base' field). This will result in schedule.py's
-            # deallocate_output_tensor() throwing an error, so a viewless tensor is
-            # created to prevent this.
-            output = make_viewless_tensor(inp = output,
-                                          requires_grad = input.requires_grad,
-                                          keep_graph = True)
-
-        return output
+# class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
+#
+#     def reset_parameters(self):
+#         init.zeros_(self.weight)
+#         init.zeros_(self.bias)
+#
+#     def forward(self, input):
+#         if self.no_persist_layer_norm:
+#             return FusedLayerNormAffineFunction.apply(
+#                 input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
+#         else:
+#             output = FastLayerNormFN.apply(
+#                 input, self.weight + 1, self.bias, self.eps)
+#
+#             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+#             # a populated '_base' field). This will result in schedule.py's
+#             # deallocate_output_tensor() throwing an error, so a viewless tensor is
+#             # created to prevent this.
+#             output = make_viewless_tensor(inp = output,
+#                                           requires_grad = input.requires_grad,
+#                                           keep_graph = True)
+#
+#         return output
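
Note on the mechanism: the hunks above fold the "layernorm 1p" (zero-centered gain) behavior into MixedFusedLayerNorm itself via self.weight_adjustment, instead of keeping a separate MixedFusedLayerNorm1P subclass. The sketch below is a minimal plain-PyTorch stand-in for that idea, not Megatron's API: the real class dispatches to Apex's FusedLayerNormAffineFunction / FastLayerNormFN kernels and reads the flag from get_args(), and the class name here is illustrative only.

# Minimal plain-PyTorch sketch of the "weight_adjustment" idea (hypothetical
# names; the real MixedFusedLayerNorm uses Apex fused kernels and get_args()).
import torch
import torch.nn.functional as F

class ZeroCenteredLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, apply_layernorm_1p=True):
        super().__init__()
        self.normalized_shape = (hidden_size,)
        self.eps = eps
        # Mirrors: self.weight_adjustment = 1 if args.apply_layernorm_1p else 0
        self.weight_adjustment = 1.0 if apply_layernorm_1p else 0.0
        # Stored gamma starts at 0 in 1p mode (cf. init.zeros_ in the old
        # MixedFusedLayerNorm1P.reset_parameters), at 1 otherwise.
        self.weight = torch.nn.Parameter(
            torch.full((hidden_size,), 1.0 - self.weight_adjustment))
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # The constant is added back at apply time, so weight decay on
        # self.weight pulls the effective gain toward 1 rather than toward 0.
        return F.layer_norm(x, self.normalized_shape,
                            self.weight + self.weight_adjustment,
                            self.bias, self.eps)

x = torch.randn(4, 16)
print(ZeroCenteredLayerNorm(16)(x).shape)  # torch.Size([4, 16])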
megatron/model/transformer.py  (+9, -8)

...
@@ -14,6 +14,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
+from megatron.model import LayerNorm
 try:
     from einops import rearrange
...
@@ -634,10 +635,10 @@ class ParallelTransformerLayer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
-        if args.apply_layernorm_1p:
-            from megatron.model import LayerNorm1P as LayerNorm
-        else:
-            from megatron.model import LayerNorm
+        # if args.apply_layernorm_1p:
+        #     from megatron.model import LayerNorm1P as LayerNorm
+        # else:
+        #     from megatron.model import LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
...
@@ -1024,10 +1025,10 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
-        if args.apply_layernorm_1p:
-            from megatron.model import LayerNorm1P as LayerNorm
-        else:
-            from megatron.model import LayerNorm
+        # if args.apply_layernorm_1p:
+        #     from megatron.model import LayerNorm1P as LayerNorm
+        # else:
+        #     from megatron.model import LayerNorm
         if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
...
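
Why the call-site change is safe: with the adjustment applied inside MixedFusedLayerNorm, ParallelTransformerLayer and ParallelTransformer no longer need to pick between LayerNorm and LayerNorm1P through a function-local import; the single module-level import of LayerNorm covers both modes. The short check below is an assumed, stand-alone setup (not part of the commit) confirming that the "+1 at apply time" parameterization matches a standard layer norm whose stored weight is larger by exactly one.

# Sanity check: storing (gamma - 1) and adding 1 when applying gives the
# same output as storing gamma directly (up to float32 rounding).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 16
x = torch.randn(2, 8, hidden)
gamma = torch.randn(hidden)   # some effective gain
beta = torch.randn(hidden)

ref = F.layer_norm(x, (hidden,), gamma, beta, 1e-5)                 # standard
out = F.layer_norm(x, (hidden,), (gamma - 1.0) + 1.0, beta, 1e-5)   # "1p" style
print(torch.allclose(ref, out, atol=1e-6))  # expected: True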