evt_fugx1 / dcu_megatron · Commits · 9800dec4

Commit 9800dec4, authored Apr 14, 2025 by dongcl

    add LightopRMSNorm

parent 0604509a

Showing 5 changed files with 135 additions and 147 deletions:

    dcu_megatron/adaptor/megatron_adaptor.py              +6   -1
    dcu_megatron/core/models/gpt/gpt_model.py             +39  -145
    dcu_megatron/core/transformer/transformer_block.py    +0   -1
    dcu_megatron/legacy/model/rms_norm.py                 +67  -0
    dcu_megatron/legacy/model/utils.py                    +23  -0
dcu_megatron/adaptor/megatron_adaptor.py

@@ -116,7 +116,9 @@ class CoreAdaptation(MegatronAdaptationABC):
 
         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
+                                    gpt_model_init_wrapper,
+                                    apply_wrapper=True)
         from megatron.core.models.gpt.gpt_model import GPTModel
         setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight',
                 shared_embedding_or_mtp_embedding_weight)

@@ -240,6 +242,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
     def patch_legacy_models(self):
         from ..legacy.model.transformer import ParallelMLP, ParallelAttention
+        from ..legacy.model.utils import get_norm
 
         # ParallecMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',

@@ -252,6 +255,8 @@ class LegacyAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
                                     torch.compile(mode="max-autotune-no-cudagraphs"),
                                     apply_wrapper=True)
+        MegatronAdaptation.register('megatron.legacy.model.utils.get_norm', get_norm)
+
 
         MegatronAdaptation.execute()
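For readers unfamiliar with the adaptor layer, the sketch below is not the project's actual MegatronAdaptation implementation; the class name TinyAdaptation and its internals are hypothetical. It only illustrates the patch-registry pattern these calls appear to rely on: register(path, obj) records a replacement for the attribute at dotted path `path`, and apply_wrapper=True appears to apply the registered object as a decorator to the original attribute rather than substituting it outright.

# Hypothetical, minimal illustration of the registry pattern used above.
import importlib


class TinyAdaptation:
    """Records monkey patches and applies them all on execute()."""
    _patches = []

    @classmethod
    def register(cls, path, obj, apply_wrapper=False):
        cls._patches.append((path, obj, apply_wrapper))

    @classmethod
    def execute(cls):
        for path, obj, apply_wrapper in cls._patches:
            owner, attr = cls._resolve(path)
            original = getattr(owner, attr)
            # apply_wrapper=True: treat the registered object as a decorator of the
            # original attribute; otherwise replace the attribute outright.
            setattr(owner, attr, obj(original) if apply_wrapper else obj)

    @staticmethod
    def _resolve(path):
        # Import the longest importable module prefix, then walk the remaining attributes.
        parts = path.split('.')
        owner = None
        for i in range(len(parts) - 1, 0, -1):
            try:
                owner = importlib.import_module('.'.join(parts[:i]))
                break
            except ImportError:
                continue
        if owner is None:
            raise ImportError(f"cannot resolve {path}")
        for name in parts[i:-1]:
            owner = getattr(owner, name)
        return owner, parts[-1]

Under that reading, the __init__ registration above composes as GPTModel.__init__ = gpt_model_init_wrapper(GPTModel.__init__), and the RMSNorm registration as RMSNorm.forward = torch.compile(mode="max-autotune-no-cudagraphs")(RMSNorm.forward).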
dcu_megatron/core/models/gpt/gpt_model.py

@@ -22,121 +22,21 @@ from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPr
 from dcu_megatron.core.transformer.transformer_config import TransformerConfig
 
 
-def gpt_model_init(
-    self,
-    config: TransformerConfig,
-    transformer_layer_spec: ModuleSpec,
-    vocab_size: int,
-    max_sequence_length: int,
-    pre_process: bool = True,
-    post_process: bool = True,
-    fp16_lm_cross_entropy: bool = False,
-    parallel_output: bool = True,
-    share_embeddings_and_output_weights: bool = False,
-    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
-    rotary_percent: float = 1.0,
-    rotary_base: int = 10000,
-    rope_scaling: bool = False,
-    rope_scaling_factor: float = 8.0,
-    scatter_embedding_sequence_parallel: bool = True,
-    seq_len_interpolation_factor: Optional[float] = None,
-    mtp_spec: ModuleSpec = None
-) -> None:
-    super(GPTModel, self).__init__(config=config)
-
-    if has_config_logger_enabled(config):
-        log_config_to_disk(config, locals(), prefix=type(self).__name__)
-
-    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
-    self.vocab_size = vocab_size
-    self.max_sequence_length = max_sequence_length
-    self.pre_process = pre_process
-    self.post_process = post_process
-    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
-    self.parallel_output = parallel_output
-    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
-    self.position_embedding_type = position_embedding_type
-
-    # megatron core pipelining currently depends on model type
-    # TODO: remove this dependency ?
-    self.model_type = ModelType.encoder_or_decoder
-
-    # These 4 attributes are needed for TensorRT-LLM export.
-    self.max_position_embeddings = max_sequence_length
-    self.rotary_percent = rotary_percent
-    self.rotary_base = rotary_base
-    self.rotary_scaling = rope_scaling
-
-    if self.pre_process:
-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=position_embedding_type,
-            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
-        )
-
-    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        self.rotary_pos_emb = RotaryEmbedding(
-            kv_channels=self.config.kv_channels,
-            rotary_percent=rotary_percent,
-            rotary_interleaved=self.config.rotary_interleaved,
-            seq_len_interpolation_factor=seq_len_interpolation_factor,
-            rotary_base=rotary_base,
-            rope_scaling=rope_scaling,
-            rope_scaling_factor=rope_scaling_factor,
-            use_cpu_initialization=self.config.use_cpu_initialization,
-        )
-
-    # Cache for RoPE tensors which do not change between iterations.
-    self.rotary_pos_emb_cache = {}
-
-    # Transformer.
-    self.decoder = TransformerBlock(
-        config=self.config,
-        spec=transformer_layer_spec,
-        pre_process=self.pre_process,
-        post_process=self.post_process
-    )
-
-    # Output
-    if post_process:
-        if self.config.defer_embedding_wgrad_compute:
-            # The embedding activation buffer preserves a reference to the input activations
-            # of the final embedding projection layer GEMM. It will hold the activations for
-            # all the micro-batches of a global batch for the last pipeline stage. Once we are
-            # done with all the back props for all the microbatches for the last pipeline stage,
-            # it will be in the pipeline flush stage. During this pipeline flush we use the
-            # input activations stored in embedding activation buffer and gradient outputs
-            # stored in gradient buffer to calculate the weight gradients for the embedding
-            # final linear layer.
-            self.embedding_activation_buffer = []
-            self.grad_output_buffer = []
-        else:
-            self.embedding_activation_buffer = None
-            self.grad_output_buffer = None
-
-        self.output_layer = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            self.vocab_size,
-            config=config,
-            init_method=config.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.pre_process
-            and self.share_embeddings_and_output_weights,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )
-
-    # add mtp
-    self.mtp_spec: ModuleSpec = mtp_spec
-    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
-    self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
-    self.recompute_mtp_norm = self.config.recompute_mtp_norm
-    self.recompute_mtp_layer = self.config.recompute_mtp_layer
-    self.mtp_loss_scale = self.config.mtp_loss_scale
-    if self.post_process and self.training and self.num_nextn_predict_layers:
-        self.mtp_layers = torch.nn.ModuleList(
-            [
-                MultiTokenPredictor(
+def gpt_model_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+
+        # add mtp
+        self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
+        if self.num_nextn_predict_layers:
+            assert hasattr(self.config, "mtp_spec")
+            self.mtp_spec: ModuleSpec = self.config.mtp_spec
+            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
+            self.recompute_mtp_norm = self.config.recompute_mtp_norm
+            self.recompute_mtp_layer = self.config.recompute_mtp_layer
+            self.mtp_loss_scale = self.config.mtp_loss_scale
+            if self.post_process and self.training:
+                self.mtp_layers = torch.nn.ModuleList(
+                    [
+                        MultiTokenPredictor(

@@ -161,16 +61,10 @@ def gpt_model_init(
                 )
 
-    if self.pre_process or self.post_process:
-        self.setup_embeddings_and_output_layer()
-
-    if has_config_logger_enabled(self.config):
-        log_config_to_disk(
-            self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
-        )
-
-    if self.num_nextn_predict_layers and (self.pre_process or self.post_process):
-        setup_mtp_embeddings(self)
+            if self.pre_process or self.post_process:
+                setup_mtp_embeddings(self)
+
+    return wrapper
 
 
 def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
     """Gets the embedding weight when share embedding and mtp embedding weights set to True.
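The net effect of this file's change: instead of maintaining a full copy of GPTModel.__init__ that must be kept in sync with upstream Megatron, the adaptor now decorates the upstream constructor and appends only the MTP-specific setup after it runs. A stand-alone sketch of that pattern (the class and helper names below are hypothetical, not from the commit):

from functools import wraps


def add_mtp_after_init(fn):
    """Decorate an upstream __init__ and append extra setup after it runs."""
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)  # run the unmodified upstream constructor first
        # Append adaptor-specific state; upstream signature changes no longer matter here.
        self.num_nextn_predict_layers = getattr(self.config, "num_nextn_predict_layers", 0)
    return wrapper


class UpstreamModel:  # hypothetical stand-in for Megatron's GPTModel
    def __init__(self, config):
        self.config = config


UpstreamModel.__init__ = add_mtp_after_init(UpstreamModel.__init__)

m = UpstreamModel(config=type("Cfg", (), {"num_nextn_predict_layers": 1})())
assert m.num_nextn_predict_layers == 1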
dcu_megatron/core/transformer/transformer_block.py

@@ -199,4 +199,3 @@ def transformer_block_forward(
         )
 
     return hidden_states
-
dcu_megatron/legacy/model/rms_norm.py

import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union

import lightop  # rmsnorm_forward, rmsnorm_backward
from functools import partial

from megatron.core.utils import is_torch_min_version

if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with lightop."""

    @staticmethod
    # @custom_fwd
    def forward(ctx,
                inp: torch.Tensor,
                ln_out: torch.Tensor,
                weight: torch.Tensor,
                eps: float,
                is_grad_enabled):
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        # output = (output, weight)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, ln_out, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, ln_out, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        return dgrad, None, dgamma, None, None, None, None, None, None


class LightopRMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # torch._dynamo.disable, applied dynamically
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            fwd_fn = _LightopRMSNorm.forward
            args = [None]

        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, ln_out, self.weight, self.eps, torch.is_grad_enabled())

        out = fwd_fn(*args)

        return out
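For reference, the lightop kernels above stand in for the standard RMSNorm computation, y = x * rsqrt(mean(x^2, dim=-1) + eps) * weight. The pure-PyTorch sketch below is not part of the commit; the fp32 accumulation is an assumption, and the commented cross-check assumes lightop and a DCU/CUDA device are available. It can serve as a numerical reference when validating the fused path:

import torch


def rmsnorm_reference(inp: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Eager RMSNorm: normalize by the root mean square over the last dimension."""
    variance = inp.float().pow(2).mean(dim=-1, keepdim=True)  # accumulate in fp32 (assumed)
    normed = inp.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(inp.dtype)


# Cross-check sketch, assuming lightop is installed and a device is available:
# x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
# norm = LightopRMSNorm(4096).to(device="cuda", dtype=torch.bfloat16)
# torch.testing.assert_close(norm(x), rmsnorm_reference(x, norm.weight),
#                            rtol=2e-2, atol=2e-2)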
dcu_megatron/legacy/model/utils.py  (new file, mode 100644)

from megatron.training import get_args
from megatron.legacy.model import LayerNorm

from .rms_norm import LightopRMSNorm


def get_norm(config):
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return LightopRMSNorm(dim=config.hidden_size,
                              eps=config.layernorm_epsilon,
                              sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")