evt_fugx1 / dcu_megatron · Commits · 763941b5

Commit 763941b5 authored Apr 23, 2025 by dongcl

    modify ParallelAttention

Parent: bf323343
Showing 2 changed files with 77 additions and 65 deletions:

    dcu_megatron/adaptor/megatron_adaptor.py    +13 -5
    dcu_megatron/legacy/model/transformer.py    +64 -60
dcu_megatron/adaptor/megatron_adaptor.py

...
@@ -141,9 +141,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'),
                                     apply_wrapper=True)
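Context on the hunk above: calling torch.compile with only options/mode and no target function returns a decorator, so what gets registered here is a compiler wrapper that is later applied to the callable named by the dotted path (presumably what apply_wrapper=True asks the adaptor to do). A minimal standalone sketch of that pattern, using a hypothetical stand-in function rather than the real topk_softmax_with_capacity:

import torch

# Stand-in for a target such as moe_utils.topk_softmax_with_capacity; the real
# Megatron function has a different signature -- this is illustration only.
def routing_scores_stub(logits, topk):
    scores = torch.softmax(logits, dim=-1)
    return torch.topk(scores, topk, dim=-1)

# torch.compile called without a function returns a decorator carrying the
# inductor options; applying it yields the compiled wrapper that the adaptor
# would then install in place of the original function.
compile_wrapper = torch.compile(options={"triton.cudagraphs": True,
                                         "triton.cudagraph_trees": False})
compiled_routing_scores = compile_wrapper(routing_scores_stub)

out_vals, out_idx = compiled_routing_scores(torch.randn(4, 8), topk=2)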
...
@@ -233,13 +233,21 @@ class LegacyAdaptation(MegatronAdaptationABC):
         self.patch_legacy_models()

     def patch_legacy_models(self):
-        from ..legacy.model.transformer import ParallelMLPPatch, ParallelAttentionPatch
+        from ..legacy.model.transformer import (
+            parallel_mlp_init_wrapper,
+            ParallelAttentionPatch,
+            parallel_attention_init_wrapper
+        )
         from ..legacy.model.utils import get_norm

         # ParallecMLP
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
-                                    ParallelMLPPatch.__init__)
+                                    parallel_mlp_init_wrapper,
+                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
+                                    parallel_attention_init_wrapper,
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.forward',
                                     ParallelAttentionPatch.forward)
...
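This second hunk also changes the patching style for the legacy models: instead of registering a full replacement method (ParallelMLPPatch.__init__), it registers wrapper factories (parallel_mlp_init_wrapper, parallel_attention_init_wrapper) with apply_wrapper=True. Assuming that flag means "apply the registered callable to the existing attribute", the effect is roughly the hand-applied equivalent below (a sketch with a dummy class, not the adaptor's actual mechanics):

from functools import wraps

class Target:
    """Stand-in for megatron.legacy.model.transformer.ParallelAttention."""
    def __init__(self, config):
        self.config = config

def example_init_wrapper(fn):
    # Same shape as parallel_attention_init_wrapper in this commit: run the
    # original __init__ first, then adjust attributes on the constructed object.
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        self.patched = True  # post-init adjustment would go here
    return wrapper

# What apply_wrapper=True presumably amounts to: the registered callable is
# applied to the original attribute rather than replacing it outright.
Target.__init__ = example_init_wrapper(Target.__init__)

assert Target(config=None).patched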
dcu_megatron/legacy/model/transformer.py

...
@@ -6,79 +6,82 @@ from megatron.core import tensor_parallel
 from megatron.legacy.model.enums import AttnType
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.legacy.model.module import MegatronModule
-from megatron.legacy.model.transformer import ParallelMLP
-from megatron.legacy.model.utils import (
-    erf_gelu,
-    openai_gelu,
-)

 try:
     from einops import rearrange
 except ImportError:
     rearrange = None

-
-class ParallelMLPPatch(MegatronModule):
-    """MLP.
-
-    MLP will take the input with h hidden state, project it to 4*h
-    hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension.
-    """
-
-    def __init__(self, config, is_expert=False):
-        super(ParallelMLP, self).__init__()
-        args = get_args()
-
-        self.add_bias = config.add_bias_linear
-
-        ffn_hidden_size = config.ffn_hidden_size
-        if config.gated_linear_unit:
-            ffn_hidden_size *= 2
-
-        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
-            config.hidden_size,
-            ffn_hidden_size,
-            config=config,
-            init_method=config.init_method,
-            bias=self.add_bias,
-            gather_output=False,
-            skip_bias_add=True,
-            is_expert=is_expert,
-        )
-
-        self.bias_gelu_fusion = False
-        self.activation_func = None
-        self.swiglu = args.swiglu
-        if args.openai_gelu:
-            self.activation_func = openai_gelu
-        elif args.onnx_safe:
-            self.activation_func = erf_gelu
-        elif args.swiglu:
+try:
+    # use fixed-length flash attention
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
+
+
+def parallel_mlp_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        args = get_args()
+        if args.swiglu:
             @torch.compile(mode="max-autotune-no-cudagraphs")
             def swiglu(x):
                 x = torch.chunk(x, 2, dim=-1)
                 return F.silu(x[0]) * x[1]
             self.activation_func = swiglu
-        elif args.squared_relu:
-            def squared_relu(x):
-                return torch.pow(F.relu(x), 2)
-            self.activation_func = squared_relu
-        else:
-            self.bias_gelu_fusion = args.bias_gelu_fusion
-            self.activation_func = F.gelu
-
-        # Project back to h.
-        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
-            config.ffn_hidden_size,
-            config.hidden_size,
-            config=config,
-            init_method=config.output_layer_init_method,
-            bias=self.add_bias,
-            skip_bias_add=True,
-            input_is_parallel=True,
-            is_expert=is_expert,
-        )
+    return wrapper
+
+
+class FlashFixedSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_func is not None, ('Please install FlashAttention first, '
+                                             'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.flash_attn_func = flash_attn_func
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
+        assert all((i.is_cuda for i in (q, k, v)))
+        output = self.flash_attn_func(q, k, v, dropout_p=self.dropout_p,
+                                      softmax_scale=self.softmax_scale,
+                                      causal=self.causal)  # [b,s,a,dim]
+        return output
+
+
+def parallel_attention_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashFixedSelfAttention(
+                causal=True, attention_dropout=self.config.attention_dropout)
+    return wrapper


 class ParallelAttentionPatch(MegatronModule):
...

...
@@ -87,6 +90,7 @@ class ParallelAttentionPatch(MegatronModule):
     Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None,
                 rotary_pos_emb=None):
...
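For context on the tensor layout FlashFixedSelfAttention relies on: flash_attn_func takes query, key, and value as (batch, seqlen, num_heads, head_dim) half-precision CUDA tensors and returns an output of the same shape, which is why the forward asserts fp16/bf16 dtypes and is_cuda. A minimal usage sketch (requires the flash-attn package and a CUDA device; the sizes are illustrative only):

import torch
from flash_attn import flash_attn_func  # requires flash-attn to be installed

# Illustrative sizes only: batch 2, sequence 128, 8 heads, head dim 64.
b, s, h, d = 2, 128, 8, 64
q = torch.randn(b, s, h, d, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# The same call FlashFixedSelfAttention.forward makes: causal attention,
# no dropout, default softmax scale (1/sqrt(head_dim)).
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])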