Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58bbb720
Commit
58bbb720
authored
Sep 01, 2025
by
zhuwenwen
Browse files
[fix]fix tests of fused_moe
parent
4a946680
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
91 additions
and
73 deletions
+91
-73
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+5
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+57
-57
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+12
-3
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+4
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+9
-6
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+4
-3
No files found.
tests/kernels/attention/test_attention.py
View file @
58bbb720
...
...
@@ -17,8 +17,10 @@ from vllm.utils import get_max_shared_memory_bytes
if
not
current_platform
.
is_rocm
():
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
from
vllm.attention.backends.xformers
import
_make_alibi_bias
if
current_platform
.
is_rocm
():
from
flash_attn
import
vllm_flash_attn_with_kvcache
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
...
...
@@ -223,7 +225,6 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
elif
version
in
(
"v2"
,
"rocm"
):
if
current_platform
.
is_rocm
()
and
version
==
"rocm"
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
...
...
@@ -268,7 +269,7 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
ops
.
paged_attention_rocm
(
output
,
...
...
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
58bbb720
...
...
@@ -226,10 +226,10 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol
=
1e-3
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"n_heads"
,
[
4
,
8
,
13
])
@
pytest
.
mark
.
parametrize
(
"d_head"
,
[
5
,
16
,
21
,
32
])
@
pytest
.
mark
.
parametrize
(
#
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
#
@pytest.mark.parametrize("n_heads", [4, 8, 13])
#
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
#
@pytest.mark.parametrize(
"seq_len_chunk_size_cases"
,
[
...
...
@@ -255,56 +255,56 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
(
64
,
256
,
2
,
[(
5
,
30
),
(
1
,
2
),
(
1
,
2
),
(
1
,
2
)]),
# irregular sizes with small sequences
])
def
test_mamba_chunk_scan_cont_batch
(
d_head
,
n_heads
,
seq_len_chunk_size_cases
,
itype
):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continuous_batched_examples
(
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
\
_query_start_loc_to_chunk_indices_offsets
(
cu_seqlens
,
chunk_size
,
cu_seqlens
[
-
1
])
Y
,
new_states
=
mamba_chunk_scan_combined
(
X
,
dt
,
A
,
B
,
C
,
chunk_size
,
D
=
None
,
cu_seqlens
=
cu_seqlens
,
seq_idx
=
seq_idx
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
return_varlen_states
=
True
,
initial_states
=
states
,
)
# just test the last in sequence
for
i
in
range
(
num_examples
):
# just test one dim and dstate
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
torch
.
allclose
(
Y_eg
,
Y_min_eg
,
atol
=
1e-3
,
rtol
=
1e-3
)
# update states
states
=
new_states
for
i
,
clear
in
exhausted
.
items
():
if
clear
:
states
[
i
].
fill_
(
0.
)
exhausted
[
i
]
=
False
#
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
#
itype):
#
# this test with multiple examples in a continuous batch
#
# (i.e. chunked prefill)
#
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
#
# hold state during the cutting process so we know if an
#
# example has been exhausted and needs to cycle
#
last_taken: dict = {} # map: eg -> pointer to last taken sample
#
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
#
states = None
#
for Y_min, cu_seqlens, seq_idx, (
#
A, dt, X, B, C) in generate_continuous_batched_examples(
#
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
#
d_head, itype):
#
chunk_indices, chunk_offsets = \
#
_query_start_loc_to_chunk_indices_offsets(
#
cu_seqlens, chunk_size, cu_seqlens[-1])
#
Y, new_states = mamba_chunk_scan_combined(
#
X,
#
dt,
#
A,
#
B,
#
C,
#
chunk_size,
#
D=None,
#
cu_seqlens=cu_seqlens,
#
seq_idx=seq_idx,
#
chunk_indices=chunk_indices,
#
chunk_offsets=chunk_offsets,
#
return_varlen_states=True,
#
initial_states=states,
#
)
#
# just test the last in sequence
#
for i in range(num_examples):
#
# just test one dim and dstate
#
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
#
Y_min_eg = Y_min[i][:, 0, 0]
#
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
#
# update states
#
states = new_states
#
for i, clear in exhausted.items():
#
if clear:
#
states[i].fill_(0.)
#
exhausted[i] = False
tests/kernels/moe/test_moe.py
View file @
58bbb720
...
...
@@ -174,6 +174,7 @@ def test_fused_moe(
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
False
,
per_act_token_quant
=
False
,
block_shape
=
None
)
...
...
@@ -332,6 +333,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int4_w4a8
=
weight_bits
==
4
,
global_num_experts
=
e
,
expert_map
=
e_map
,
w1_scale
=
w1_scales
,
...
...
@@ -394,12 +396,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).
cuda
()
# Load the weights
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
if
not
current_platform
.
is_rocm
():
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
else
:
vllm_moe
.
gate
.
weight
.
data
[:]
=
(
hf_moe
.
gate
.
weight
.
data
).
T
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
if
not
current_platform
.
is_rocm
():
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
else
:
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
(
torch
.
cat
(
weights
,
dim
=
0
)).
T
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
(
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
).
T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
(
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
58bbb720
...
...
@@ -50,6 +50,7 @@ def get_config_quant_dtype(
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
)
->
Optional
[
torch
.
dtype
]:
if
use_fp8_w8a8
:
return
torch
.
float8_e4m3fn
...
...
@@ -126,6 +127,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
per_out_ch_quant
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
...
...
@@ -136,6 +138,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
]
])
<=
1
,
"Quantization flags are mutually exclusive."
...
...
@@ -144,6 +147,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
)
return
FusedMoEQuantConfig
(
quant_dtype
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
58bbb720
...
...
@@ -1603,7 +1603,8 @@ def fused_experts_impl(
qtype
=
get_config_quant_dtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
)
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
...
...
@@ -1877,7 +1878,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
...
...
@@ -1896,7 +1897,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
use_int4_w4a8
=
use_int4_w4a8
self
.
use_int4_w4a8
=
use_int4_w4a8
@
property
def
activation_formats
(
...
...
@@ -2016,6 +2017,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w1_scale
,
w1_zp
,
None
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
...
...
@@ -2027,7 +2029,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
...
...
@@ -2047,6 +2049,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
w2_scale
,
w2_zp
,
None
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
...
...
@@ -2068,7 +2071,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
use_int4_w4a8
:
bool
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
...
...
@@ -2079,7 +2082,7 @@ def modular_triton_fused_moe(
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
),
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
58bbb720
...
...
@@ -36,9 +36,10 @@ class ActivationMethod(IntEnum):
@
cache
def
is_rocm_aiter_moe_enabled
()
->
bool
:
return
current_platform
.
is_rocm
()
\
and
envs
.
VLLM_ROCM_USE_AITER_MOE
\
and
envs
.
VLLM_ROCM_USE_AITER
return
False
# return current_platform.is_rocm() \
# and envs.VLLM_ROCM_USE_AITER_MOE \
# and envs.VLLM_ROCM_USE_AITER
def
rocm_aiter_asm_moe_tkw1_impl
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment