OpenDAS / Megatron-LM

Commit 56400eb5, authored Apr 15, 2025 by wangxj

Modify the legacy norm

Parent: 4d19cbac
Pipeline #2634 passed
Showing 3 changed files with 86 additions and 2 deletions:

megatron/legacy/model/rms_norm.py   +81  -0
megatron/legacy/model/utils.py       +4  -1
megatron/training/arguments.py       +1  -1
megatron/legacy/model/rms_norm.py

```diff
@@ -31,3 +31,84 @@ class RMSNorm(torch.nn.Module):
     def forward(self, x):
         output = self._norm(x.float()).type_as(x)
         return output * self.weight
+
+
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+import lightop  # provides rmsnorm_forward, rmsnorm_backward
+from functools import partial
+from megatron.core.utils import is_torch_min_version
+
+if is_torch_min_version("2.4.0a0"):
+    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
+    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
+else:
+    custom_fwd = torch.cuda.amp.custom_fwd
+    custom_bwd = torch.cuda.amp.custom_bwd
+
+
+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+class _LightopRMSNorm(torch.autograd.Function):
+    """RMSNorm implemented with lightop kernels."""
+
+    @staticmethod
+    # @custom_fwd
+    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor, ln_out: torch.Tensor,
+                eps: float, is_grad_enabled):
+        # rmsnorm_forward returns a tuple: [0] is the normalized output, [1] is rsigma.
+        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
+        # output = (output, weight)
+        # print_rank_0(f"_LightopRMSNorm: output({output[0].shape, output[1].shape}) = lightop.rmsfwd(inp{inp.shape}, weight{weight.shape}, ...)")
+        rsigma = output[1]
+        if is_grad_enabled:
+            ctx.save_for_backward(inp, weight, rsigma)
+        return output[0]
+
+    @staticmethod
+    # @custom_bwd
+    def backward(ctx, grad_output):
+        inp, weight, rsigma = ctx.saved_tensors
+        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
+        # print_rank_0(f"_LightopRMSNorm: dgrad{dgrad.shape}, dgamma{dgamma.shape} = lightop.rmsbwd(grad_output{grad_output.shape}, inp{inp.shape}, rsigma{rsigma.shape}, weight{weight.shape})")
+        # One gradient per forward() argument: inp, weight, ln_out, eps, is_grad_enabled.
+        return dgrad, dgamma, None, None, None
+
+
+class LightopRMSNorm(torch.nn.Module):
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """RMS Normalization module
+
+        Args:
+            dim (int): The width of input, i.e. hidden size
+            eps (float): epsilon to use for the norm, default to 1e-6
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    # @no_torch_dynamo()  # dynamically applies torch._dynamo.disable
+    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
+        if torch.is_grad_enabled():
+            fwd_fn = _LightopRMSNorm.apply
+            args = []
+        else:
+            # Outside autograd, call forward() directly and pass None for ctx.
+            fwd_fn = _LightopRMSNorm.forward
+            args = [None]
+        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
+        args += (inp, self.weight, ln_out, self.eps, torch.is_grad_enabled())
+        out = fwd_fn(*args)
+        return out
```
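The new module should stay numerically close to the eager RMSNorm already defined in this file (scale the input by the reciprocal RMS, then by the learned weight). A minimal sanity-check sketch along those lines is shown below; it is not part of the commit, and it assumes lightop is installed, a CUDA device is available, and the kernel supports the chosen dtype. The tolerances are arbitrary.

```python
# Hypothetical sanity check: compare LightopRMSNorm against the pure-PyTorch
# RMSNorm math used elsewhere in this file. Requires lightop and a CUDA device.
import torch
from megatron.legacy.model.rms_norm import LightopRMSNorm

def reference_rmsnorm(x, weight, eps=1e-6):
    # Same math as RMSNorm._norm: x * rsqrt(mean(x^2) + eps), then scale by weight.
    norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return norm.type_as(x) * weight

torch.manual_seed(0)
hidden = 4096
x = torch.randn(2, 8, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)

norm = LightopRMSNorm(dim=hidden, eps=1e-6).to(device="cuda", dtype=torch.bfloat16)
out = norm(x)                                   # exercises _LightopRMSNorm.forward
ref = reference_rmsnorm(x, norm.weight, eps=1e-6)

# Loose tolerances: a fused bf16 kernel is not expected to be bit-exact.
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
out.sum().backward()                            # exercises _LightopRMSNorm.backward
print(x.grad.shape, norm.weight.grad.shape)
```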
megatron/legacy/model/utils.py

```diff
@@ -7,7 +7,7 @@ import math
 import torch
 
 from megatron.training import get_args
-from megatron.legacy.model import LayerNorm, RMSNorm
+from megatron.legacy.model import LayerNorm, RMSNorm, LightopRMSNorm
 from megatron.core.jit import jit_fuser
 
 def init_method_normal(sigma):
@@ -75,5 +75,8 @@ def get_norm(config):
         return RMSNorm(dim=config.hidden_size,
                        eps=config.layernorm_epsilon,
                        sequence_parallel=config.sequence_parallel)
+    elif args.normalization == "LightopRMSNorm":
+        return LightopRMSNorm(dim=config.hidden_size,
+                              eps=config.layernorm_epsilon)
     else:
         raise Exception(f"unsupported norm type '{args.normalization}'.")
```
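For context, `get_norm` keys off `args.normalization` (set by the `--normalization` flag changed below) and, unlike the `RMSNorm` branch, passes only the hidden size and epsilon to the new class, since `LightopRMSNorm.__init__` takes no `sequence_parallel` argument. A standalone sketch of that dispatch, with a made-up config object purely for illustration:

```python
# Illustrative stand-alone mirror of the get_norm() dispatch above; the config
# values are made up and the returned dicts just show which kwargs each norm gets.
from types import SimpleNamespace

def norm_kwargs(normalization, config):
    if normalization == "RMSNorm":
        return dict(dim=config.hidden_size,
                    eps=config.layernorm_epsilon,
                    sequence_parallel=config.sequence_parallel)
    elif normalization == "LightopRMSNorm":
        # LightopRMSNorm.__init__ only takes dim and eps.
        return dict(dim=config.hidden_size,
                    eps=config.layernorm_epsilon)
    raise Exception(f"unsupported norm type '{normalization}'.")

config = SimpleNamespace(hidden_size=4096, layernorm_epsilon=1e-5, sequence_parallel=False)
print(norm_kwargs("LightopRMSNorm", config))   # {'dim': 4096, 'eps': 1e-05}
```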
megatron/training/arguments.py

```diff
@@ -1112,7 +1112,7 @@ def _add_network_size_args(parser):
                        help='Pad the vocab size to be divisible by this value.'
                        'This is added for computational efficieny reasons.')
     group.add_argument('--normalization', default='LayerNorm',
-                       choices=['LayerNorm', 'RMSNorm'],
+                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                        help='Which normalization technique to use.')
     group.add_argument('--norm-epsilon', type=float, default=1e-5,
                        help='Epsilon for layer norm and RMS norm.')
```
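The only functional change here is the extra entry in `choices`, so argparse now accepts `--normalization LightopRMSNorm` instead of rejecting it as an invalid choice. A toy stand-in parser (not Megatron's real one) demonstrating the effect:

```python
# Toy reproduction of the flag change; this is not Megatron's actual parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--normalization', default='LayerNorm',
                    choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
                    help='Which normalization technique to use.')
parser.add_argument('--norm-epsilon', type=float, default=1e-5,
                    help='Epsilon for layer norm and RMS norm.')

args = parser.parse_args(['--normalization', 'LightopRMSNorm'])
print(args.normalization)   # LightopRMSNorm
# Before this commit the same value would exit with an "invalid choice" error.
```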