evt_fugx1 / dcu_megatron / Commits

Commit 1e498ef0
Authored Apr 15, 2025 by wangxj
Parent: c1200c81

Commit message: Modify the lightop RMSNorm (修改lightop的rmsnorm)
Showing 3 changed files, with 54 additions and 3 deletions (+54 -3):

- dcu_megatron/legacy/model/rms_norm.py (+34 -0)
- dcu_megatron/legacy/model/utils.py (+7 -3)
- dcu_megatron/training/arguments.py (+13 -0)
dcu_megatron/legacy/model/rms_norm.py (view file @ 1e498ef0)
```python
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from torch import nn
from typing import Any, Callable, Dict, Optional, Tuple, Union


class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
                this marks the weights as needing to be allreduced.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
...
```
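In this implementation, `_norm` divides each hidden vector by its root mean square, `x * rsqrt(mean(x**2) + eps)`, and `forward` rescales the normalized value with the learned `weight`; the `@torch.compile(mode="max-autotune-no-cudagraphs")` decorator asks PyTorch 2.x to fuse these elementwise ops. A minimal smoke test of the module as reconstructed above (the tensor shapes are illustrative, not from the commit):

```python
# Minimal smoke test for the RMSNorm module above. The shapes are illustrative;
# torch.compile requires PyTorch >= 2.0 and compiles on the first call.
import torch

norm = RMSNorm(dim=4096, eps=1e-6)
x = torch.randn(2, 16, 4096)      # (batch, sequence, hidden)

y = norm(x)

assert y.shape == x.shape         # normalization preserves the input shape
# With weight initialized to ones, each hidden vector of the output should
# have root mean square close to 1.
rms = y.float().pow(2).mean(-1).sqrt()
print(rms.mean().item())          # ~1.0
```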
dcu_megatron/legacy/model/utils.py (view file @ 1e498ef0)
```diff
 from megatron.training import get_args
 from megatron.legacy.model import LayerNorm
-from .rms_norm import LightopRMSNorm
+from .rms_norm import RMSNorm, LightopRMSNorm


 def get_norm(config):
...
@@ -15,8 +15,12 @@ def get_norm(config):
     elif args.normalization == "RMSNorm":
         if args.apply_layernorm_1p:
             raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')
-        return LightopRMSNorm(dim=config.hidden_size,
-                              eps=config.layernorm_epsilon)
+        return RMSNorm(dim=config.hidden_size,
+                       eps=config.layernorm_epsilon,
+                       sequence_parallel=config.sequence_parallel)
+    elif args.normalization == "LightopRMSNorm":
+        return LightopRMSNorm(dim=config.hidden_size,
+                              eps=config.layernorm_epsilon)
     else:
         raise Exception(f"unsupported norm type '{args.normalization}'.")
```
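After this change, `--normalization RMSNorm` resolves to the pure-PyTorch, `torch.compile`'d `RMSNorm` (with its weight marked for the sequence-parallel allreduce), while the new `--normalization LightopRMSNorm` branch selects the lightop kernel explicitly. A standalone sketch of the dispatch, with `SimpleNamespace` objects standing in for Megatron's real `args` and `config` (the field values are illustrative):

```python
# Sketch of the dispatch get_norm performs after this commit. SimpleNamespace
# stand-ins replace Megatron's real args/config objects; values are illustrative.
from types import SimpleNamespace
from dcu_megatron.legacy.model.rms_norm import RMSNorm, LightopRMSNorm

args = SimpleNamespace(normalization="LightopRMSNorm", apply_layernorm_1p=False)
config = SimpleNamespace(hidden_size=4096, layernorm_epsilon=1e-5,
                         sequence_parallel=False)

if args.normalization == "RMSNorm":
    # Pure-PyTorch path: the torch.compile'd RMSNorm, with its weight
    # tagged for sequence parallelism.
    norm = RMSNorm(dim=config.hidden_size,
                   eps=config.layernorm_epsilon,
                   sequence_parallel=config.sequence_parallel)
elif args.normalization == "LightopRMSNorm":
    # Lightop kernel path, now selected explicitly rather than under "RMSNorm".
    norm = LightopRMSNorm(dim=config.hidden_size,
                          eps=config.layernorm_epsilon)
```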
dcu_megatron/training/arguments.py (view file @ 1e498ef0)
```diff
@@ -51,6 +51,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     # Standard arguments.
     parser = _add_network_size_args(parser)
+    parser = _add_extra_network_size_args(parser)
     parser = _add_regularization_args(parser)
     parser = _add_training_args(parser)
     parser = _add_extra_training_args(parser)
...
@@ -106,6 +107,18 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     return args


+def _add_extra_network_size_args(parser):
+    # Remove the original argument
+    remove_original_params(parser, ["normalization"])
+
+    # Redefine the argument
+    group = parser.add_argument_group(title='extra network size args')
+    group.add_argument('--normalization', default='LayerNorm',
+                       choices=['LayerNorm', 'RMSNorm', 'LightopRMSNorm'],
+                       help='Which normalization technique to use.')
+    return parser
+
+
 def _add_extra_distributed_args(parser):
     group = parser.add_argument_group(title='extra distributed args')
     group.add_argument('--rank', default=-1, type=int,
...
```
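`remove_original_params` is called here but not defined in this diff, and argparse has no public API for unregistering an option, so helpers like it typically reach into the parser's private state. A hypothetical sketch of what it might look like (the name and signature are inferred from the call site above; the internals are an assumption, not the project's actual code):

```python
# Hypothetical sketch of remove_original_params; the real helper is not shown
# in this diff. It manipulates argparse's private state (_actions and
# _option_string_actions), the usual workaround for redefining an option that
# upstream Megatron already registered.
import argparse

def remove_original_params(parser: argparse.ArgumentParser, names):
    for name in names:
        option = '--' + name
        for action in list(parser._actions):
            if option in action.option_strings:
                parser._remove_action(action)  # detach the registered action
                for opt in action.option_strings:
                    # Free the option strings so add_argument can reuse them
                    # without raising a "conflicting option string" error.
                    parser._option_string_actions.pop(opt, None)
```

With `--normalization` re-registered this way, a launch script can pass `--normalization LightopRMSNorm` to pick the lightop kernel, a choice upstream Megatron's parser would otherwise reject.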