Commit 9800dec4, authored Apr 14, 2025 by dongcl
Parent: 0604509a

    add LightopRMSNorm

Showing 5 changed files, with 135 additions and 147 deletions (+135, -147):

    dcu_megatron/adaptor/megatron_adaptor.py             +6    -1
    dcu_megatron/core/models/gpt/gpt_model.py            +39   -145
    dcu_megatron/core/transformer/transformer_block.py   +0    -1
    dcu_megatron/legacy/model/rms_norm.py                +67   -0
    dcu_megatron/legacy/model/utils.py                   +23   -0
dcu_megatron/adaptor/megatron_adaptor.py

@@ -116,7 +116,9 @@ class CoreAdaptation(MegatronAdaptationABC):

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                    gpt_model_init_wrapper, apply_wrapper=True)

        from megatron.core.models.gpt.gpt_model import GPTModel
        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)

@@ -240,6 +242,7 @@ class LegacyAdaptation(MegatronAdaptationABC):

    def patch_legacy_models(self):
        from ..legacy.model.transformer import ParallelMLP, ParallelAttention
        from ..legacy.model.utils import get_norm

        # ParallelMLP
        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
                                    ...

@@ -252,6 +255,8 @@ class LegacyAdaptation(MegatronAdaptationABC):

        MegatronAdaptation.register('megatron.legacy.model.rms_norm.RMSNorm.forward',
                                    torch.compile(mode="max-autotune-no-cudagraphs"),
                                    apply_wrapper=True)
        MegatronAdaptation.register('megatron.legacy.model.utils.get_norm', get_norm)

        MegatronAdaptation.execute()
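The register/execute pattern above swaps Megatron symbols out by dotted path: each call records a target and a replacement (or, with apply_wrapper=True, a decorator applied to the existing attribute), and execute() installs them all at once. The MegatronAdaptation class itself is not part of this diff; the snippet below is only a minimal sketch of such a registry, with hypothetical names, to illustrate the mechanics.

import importlib


class PatchRegistry:
    """Minimal stand-in for a register/execute patch registry (illustrative only)."""

    def __init__(self):
        self._patches = []  # (dotted_path, replacement, apply_wrapper)

    def register(self, dotted_path, replacement, apply_wrapper=False):
        self._patches.append((dotted_path, replacement, apply_wrapper))

    def _resolve(self, dotted_path):
        # Split "pkg.module.Class.attr" into the owning object and the final attribute name.
        parts = dotted_path.split(".")
        for i in range(len(parts) - 1, 0, -1):
            try:
                owner = importlib.import_module(".".join(parts[:i]))
                break
            except ModuleNotFoundError:
                continue
        else:
            raise ImportError(f"cannot import any prefix of {dotted_path}")
        *attrs, name = parts[i:]
        for attr in attrs:
            owner = getattr(owner, attr)
        return owner, name

    def execute(self):
        for dotted_path, replacement, apply_wrapper in self._patches:
            owner, name = self._resolve(dotted_path)
            if apply_wrapper:
                # Treat the replacement as a decorator around the existing attribute,
                # as with gpt_model_init_wrapper and torch.compile(...) above.
                setattr(owner, name, replacement(getattr(owner, name)))
            else:
                setattr(owner, name, replacement)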
dcu_megatron/core/models/gpt/gpt_model.py

@@ -22,154 +22,48 @@ from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor

from dcu_megatron.core.transformer.transformer_config import TransformerConfig


def gpt_model_init(
    self,
    config: TransformerConfig,
    transformer_layer_spec: ModuleSpec,
    vocab_size: int,
    max_sequence_length: int,
    pre_process: bool = True,
    post_process: bool = True,
    fp16_lm_cross_entropy: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    rotary_percent: float = 1.0,
    rotary_base: int = 10000,
    rope_scaling: bool = False,
    rope_scaling_factor: float = 8.0,
    scatter_embedding_sequence_parallel: bool = True,
    seq_len_interpolation_factor: Optional[float] = None,
    mtp_spec: ModuleSpec = None,
) -> None:
    super(GPTModel, self).__init__(config=config)

    if has_config_logger_enabled(config):
        log_config_to_disk(config, locals(), prefix=type(self).__name__)

    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
    self.vocab_size = vocab_size
    self.max_sequence_length = max_sequence_length
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
    self.parallel_output = parallel_output
    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
    self.position_embedding_type = position_embedding_type

    # megatron core pipelining currently depends on model type
    # TODO: remove this dependency ?
    self.model_type = ModelType.encoder_or_decoder

    # These 4 attributes are needed for TensorRT-LLM export.
    self.max_position_embeddings = max_sequence_length
    self.rotary_percent = rotary_percent
    self.rotary_base = rotary_base
    self.rotary_scaling = rope_scaling

    if self.pre_process:
        self.embedding = LanguageModelEmbedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=position_embedding_type,
            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
        )

    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
        self.rotary_pos_emb = RotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
            rotary_base=rotary_base,
            rope_scaling=rope_scaling,
            rope_scaling_factor=rope_scaling_factor,
            use_cpu_initialization=self.config.use_cpu_initialization,
        )

    # Cache for RoPE tensors which do not change between iterations.
    self.rotary_pos_emb_cache = {}

    # Transformer.
    self.decoder = TransformerBlock(
        config=self.config,
        spec=transformer_layer_spec,
        pre_process=self.pre_process,
        post_process=self.post_process,
    )

    # Output
    if post_process:
        if self.config.defer_embedding_wgrad_compute:
            # The embedding activation buffer preserves a reference to the input activations
            # of the final embedding projection layer GEMM. It will hold the activations for
            # all the micro-batches of a global batch for the last pipeline stage. Once we are
            # done with all the back props for all the microbatches for the last pipeline stage,
            # it will be in the pipeline flush stage. During this pipeline flush we use the
            # input activations stored in embedding activation buffer and gradient outputs
            # stored in gradient buffer to calculate the weight gradients for the embedding
            # final linear layer.
            self.embedding_activation_buffer = []
            self.grad_output_buffer = []
        else:
            self.embedding_activation_buffer = None
            self.grad_output_buffer = None

        self.output_layer = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            self.vocab_size,
            config=config,
            init_method=config.init_method,
            bias=False,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,
            embedding_activation_buffer=self.embedding_activation_buffer,
            grad_output_buffer=self.grad_output_buffer,
        )

    # add mtp
    self.mtp_spec: ModuleSpec = mtp_spec
    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
    self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
    self.recompute_mtp_norm = self.config.recompute_mtp_norm
    self.recompute_mtp_layer = self.config.recompute_mtp_layer
    self.mtp_loss_scale = self.config.mtp_loss_scale
    if self.post_process and self.training and self.num_nextn_predict_layers:
        self.mtp_layers = torch.nn.ModuleList(
            [
                MultiTokenPredictor(
                    config,
                    self.mtp_spec.submodules,
                    vocab_size=self.vocab_size,
                    max_sequence_length=self.max_sequence_length,
                    layer_number=i,
                    pre_process=self.pre_process,
                    fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
                    parallel_output=self.parallel_output,
                    position_embedding_type=self.position_embedding_type,
                    rotary_percent=self.rotary_percent,
                    seq_len_interpolation_factor=seq_len_interpolation_factor,
                    share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
                    recompute_mtp_norm=self.recompute_mtp_norm,
                    recompute_mtp_layer=self.recompute_mtp_layer,
                    add_output_layer_bias=False,
                )
                for i in range(self.num_nextn_predict_layers)
            ]
        )

    if self.pre_process or self.post_process:
        self.setup_embeddings_and_output_layer()

    if has_config_logger_enabled(self.config):
        log_config_to_disk(self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt')

    if self.pre_process or self.post_process:
        setup_mtp_embeddings(self)


def gpt_model_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)

        # add mtp
        self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
        if self.num_nextn_predict_layers:
            assert hasattr(self.config, "mtp_spec")
            self.mtp_spec: ModuleSpec = self.config.mtp_spec
            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
            self.recompute_mtp_norm = self.config.recompute_mtp_norm
            self.recompute_mtp_layer = self.config.recompute_mtp_layer
            self.mtp_loss_scale = self.config.mtp_loss_scale
            if self.post_process and self.training:
                self.mtp_layers = torch.nn.ModuleList(
                    [
                        MultiTokenPredictor(
                            self.config,
                            self.mtp_spec.submodules,
                            vocab_size=self.vocab_size,
                            max_sequence_length=self.max_sequence_length,
                            layer_number=i,
                            pre_process=self.pre_process,
                            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
                            parallel_output=self.parallel_output,
                            position_embedding_type=self.position_embedding_type,
                            rotary_percent=self.rotary_percent,
                            seq_len_interpolation_factor=seq_len_interpolation_factor,
                            share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
                            recompute_mtp_norm=self.recompute_mtp_norm,
                            recompute_mtp_layer=self.recompute_mtp_layer,
                            add_output_layer_bias=False,
                        )
                        for i in range(self.num_nextn_predict_layers)
                    ]
                )

        if self.num_nextn_predict_layers and (self.pre_process or self.post_process):
            setup_mtp_embeddings(self)

    return wrapper


def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
    ...
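gpt_model_init_wrapper is the usual decorate-the-constructor pattern: run the original __init__ unchanged, then attach the MTP-related attributes afterwards, so the patch composes with whatever __init__ Megatron ships. A stripped-down, self-contained sketch of the same pattern follows; Config, Model and num_extra_heads are hypothetical stand-ins, not part of this commit.

from functools import wraps


def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)          # run the original __init__ unchanged
        extra = getattr(self.config, "num_extra_heads", 0)
        if extra:                          # only extend the model when configured
            self.extra_heads = [f"head_{i}" for i in range(extra)]
    return wrapper


class Config:
    num_extra_heads = 2


class Model:
    def __init__(self, config):
        self.config = config


Model.__init__ = init_wrapper(Model.__init__)
m = Model(Config())
print(m.extra_heads)  # ['head_0', 'head_1']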
dcu_megatron/core/transformer/transformer_block.py

@@ -199,4 +199,3 @@ def transformer_block_forward(

        )

    return hidden_states
dcu_megatron/legacy/model/rms_norm.py

import torch
from typing import Any, Callable, Dict, Optional, Tuple, Union

import lightop  # rmsnorm_forward, rmsnorm_backward
from functools import partial

from megatron.core.utils import is_torch_min_version

if is_torch_min_version("2.4.0a0"):
    custom_fwd = partial(torch.amp.custom_fwd, device_type="cuda")
    custom_bwd = partial(torch.amp.custom_bwd, device_type="cuda")
else:
    custom_fwd = torch.cuda.amp.custom_fwd
    custom_bwd = torch.cuda.amp.custom_bwd


class _LightopRMSNorm(torch.autograd.Function):
    """RMSNorm implemented with lightop."""

    @staticmethod
    # @custom_fwd
    def forward(ctx, inp: torch.Tensor, ln_out: torch.Tensor, weight: torch.Tensor, eps: float, is_grad_enabled):
        output = lightop.rmsnorm_forward(inp, weight, ln_out, eps, training=True)
        # output = (output, weight)
        rsigma = output[1]
        if is_grad_enabled:
            ctx.save_for_backward(inp, weight, ln_out, rsigma)
        return output[0]

    @staticmethod
    # @custom_bwd
    def backward(ctx, grad_output):
        inp, weight, ln_out, rsigma = ctx.saved_tensors
        dgrad, dgamma = lightop.rmsnorm_backward(grad_output, inp, rsigma, weight)
        return dgrad, None, dgamma, None, None, None, None, None, None


class LightopRMSNorm(torch.nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6,):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
        """
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    # @no_torch_dynamo()  # dynamically disable via torch._dynamo.disable
    def forward(self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None):
        if torch.is_grad_enabled():
            fwd_fn = _LightopRMSNorm.apply
            args = []
        else:
            fwd_fn = _LightopRMSNorm.forward
            args = [None]
        ln_out = torch.empty_like(inp, dtype=inp.dtype, memory_format=torch.contiguous_format)
        args += (inp, ln_out, self.weight, self.eps, torch.is_grad_enabled())
        out = fwd_fn(*args)
        return out
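Functionally, RMSNorm scales the input by the reciprocal root-mean-square of its last dimension and a learned weight: y = x * rsqrt(mean(x**2) + eps) * w. A parity check of LightopRMSNorm against a plain-PyTorch reference is a reasonable smoke test; the sketch below assumes a CUDA/DCU device is available and that the lightop kernels used above are importable.

def reference_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain-PyTorch RMSNorm over the last dimension, used only as a reference.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight


if torch.cuda.is_available():
    x = torch.randn(4, 128, dtype=torch.float32, device="cuda")
    norm = LightopRMSNorm(dim=128).cuda()
    torch.testing.assert_close(norm(x), reference_rmsnorm(x, norm.weight), rtol=1e-3, atol=1e-3)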
dcu_megatron/legacy/model/utils.py (new file, mode 100644)

from megatron.training import get_args
from megatron.legacy.model import LayerNorm

from .rms_norm import LightopRMSNorm


def get_norm(config):
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return LightopRMSNorm(dim=config.hidden_size,
                              eps=config.layernorm_epsilon,
                              sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
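This get_norm mirrors the signature of megatron.legacy.model.utils.get_norm and is registered over it by megatron_adaptor.py above, so legacy layers that build their norms through get_norm(config) pick up LightopRMSNorm whenever RMSNorm is the configured normalization. A hedged sketch of a typical call site follows; the layer below is illustrative and not part of this commit.

import torch


class LegacyLayerSketch(torch.nn.Module):
    """Illustrative consumer of get_norm; real legacy layers live in megatron.legacy.model."""

    def __init__(self, config):
        super().__init__()
        # With RMSNorm this now returns LightopRMSNorm; with LayerNorm it
        # returns the fused LayerNorm as before.
        self.input_norm = get_norm(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.input_norm(hidden_states)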