Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
daefd764
Commit
daefd764
authored
Apr 30, 2025
by
zhuwenwen
Browse files
skip is_rocm_aiter_moe_enabled and add mla pad
parent
43a52016
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
20 deletions
+19
-20
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+3
-4
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+7
-7
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-9
No files found.
vllm/attention/backends/mla/common.py
View file @
daefd764
...
@@ -1088,8 +1088,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1088,8 +1088,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# maybe_padded_v = torch.nn.functional.pad(
# maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
]
-
32
,
value
=
0
)
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
]
,
value
=
0
)
v_tmp
=
maybe_padded_v
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
maybe_padded_v
=
maybe_padded_v
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
and
not
return_softmax_lse
:
...
@@ -1120,8 +1120,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1120,8 +1120,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_out
=
self
.
flash_attn_varlen_func
(
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
# v=maybe_padded_v,
v
=
maybe_padded_v
,
v
=
v_tmp
,
return_attn_probs
=
return_softmax_lse
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
**
kwargs
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
daefd764
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
.rocm_aiter_fused_moe
import
is_rocm_aiter_moe_enabled
#
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
...
@@ -1141,9 +1141,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
...
@@ -1141,9 +1141,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
if
is_rocm_aiter_moe_enabled
():
#
if is_rocm_aiter_moe_enabled():
from
.rocm_aiter_fused_moe
import
rocm_aiter_topk_softmax
#
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return
rocm_aiter_topk_softmax
#
return rocm_aiter_topk_softmax
return
vllm_topk_softmax
return
vllm_topk_softmax
...
@@ -1405,9 +1405,9 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
...
@@ -1405,9 +1405,9 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
is_rocm_aiter_moe_enabled
():
#
if is_rocm_aiter_moe_enabled():
from
.rocm_aiter_fused_moe
import
rocm_aiter_fused_experts
#
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return
rocm_aiter_fused_experts
#
return rocm_aiter_fused_experts
if
inplace
:
if
inplace
:
return
torch_vllm_inplace_fused_experts
return
torch_vllm_inplace_fused_experts
return
torch_vllm_outplace_fused_experts
return
torch_vllm_outplace_fused_experts
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
daefd764
...
@@ -135,15 +135,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -135,15 +135,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
w13_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
)
layer
.
w13_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w13_weight
.
data
)
layer
.
w2_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
)
layer
.
w2_weight
.
data
=
self
.
_maybe_pad_weight
(
layer
.
w2_weight
.
data
)
# Lazy import to avoid importing triton.
# Lazy import to avoid importing triton.
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
#
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled
,
shuffle_weights
)
#
is_rocm_aiter_moe_enabled, shuffle_weights)
if
is_rocm_aiter_moe_enabled
():
#
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
#
# reshaping weights is required for aiter moe kernel.
shuffled_w13
,
shuffled_w2
=
shuffle_weights
(
#
shuffled_w13, shuffled_w2 = shuffle_weights(
layer
.
w13_weight
.
data
,
layer
.
w2_weight
.
data
)
#
layer.w13_weight.data, layer.w2_weight.data)
layer
.
w13_weight
.
data
=
shuffled_w13
#
layer.w13_weight.data = shuffled_w13
layer
.
w2_weight
.
data
=
shuffled_w2
#
layer.w2_weight.data = shuffled_w2
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
():
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
X86
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment