Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
26940c4c
Commit
26940c4c
authored
Apr 23, 2025
by
dongcl
Browse files
modify rms_norm.py
parent
763941b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
39 deletions
+4
-39
dcu_megatron/legacy/model/rms_norm.py
dcu_megatron/legacy/model/rms_norm.py
+2
-37
dcu_megatron/legacy/model/utils.py
dcu_megatron/legacy/model/utils.py
+2
-2
No files found.
dcu_megatron/legacy/model/rms_norm.py
View file @
26940c4c
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import
torch
import
torch
from
torch
import
nn
from
typing
import
Optional
import
lightop
class
RMSNorm
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
eps
:
float
=
1e-6
,
sequence_parallel
:
bool
=
False
,
config
:
dict
=
None
):
"""RMS Normaliation module
Args:
dim (int): The width of input, i.e. hidden size
eps (float): epsilon to use for the norm, default to 1e-6
sequence_parallel (bool): Set to true if sequence parallelism is being used,
this marks the weights as needing to be allreduced.
"""
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
dim
))
setattr
(
self
.
weight
,
'sequence_parallel'
,
sequence_parallel
)
def
_norm
(
self
,
x
):
return
x
*
torch
.
rsqrt
(
x
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
+
self
.
eps
)
@
torch
.
compile
(
mode
=
"max-autotune-no-cudagraphs"
)
def
forward
(
self
,
x
):
output
=
self
.
_norm
(
x
.
float
()).
type_as
(
x
)
return
output
*
self
.
weight
import
torch
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
,
Union
import
lightop
# rmsnorm_forward,rmsnorm_backward
from
functools
import
partial
from
functools
import
partial
from
megatron.core.utils
import
is_torch_min_version
from
megatron.core.utils
import
is_torch_min_version
...
...
dcu_megatron/legacy/model/utils.py
View file @
26940c4c
from
megatron.training
import
get_args
from
megatron.training
import
get_args
from
megatron.legacy.model
import
LayerNorm
from
megatron.legacy.model
import
LayerNorm
,
RMSNorm
from
.rms_norm
import
RMSNorm
,
LightopRMSNorm
from
.rms_norm
import
LightopRMSNorm
def
get_norm
(
config
):
def
get_norm
(
config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment