Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1570 additions
and
304 deletions
+1570
-304
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+18
-1
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+42
-5
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+8
-8
tests/kernels/attention/test_encoder_decoder_attn.py
tests/kernels/attention/test_encoder_decoder_attn.py
+2
-2
tests/kernels/attention/test_mla_decode_cpu.py
tests/kernels/attention/test_mla_decode_cpu.py
+1
-4
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+4
-2
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+1
-4
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+4
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+13
-14
tests/kernels/moe/parallel_utils.py
tests/kernels/moe/parallel_utils.py
+9
-13
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+244
-37
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+296
-0
tests/kernels/moe/test_block_int8.py
tests/kernels/moe/test_block_int8.py
+147
-0
tests/kernels/moe/test_cutlass_grouped_gemm.py
tests/kernels/moe/test_cutlass_grouped_gemm.py
+116
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+23
-9
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+34
-86
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+76
-44
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+226
-0
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+216
-71
tests/kernels/moe/test_moe_align_block_size.py
tests/kernels/moe/test_moe_align_block_size.py
+90
-0
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
tests/kernels/attention/test_attention.py
View file @
99324e25
...
...
@@ -10,6 +10,7 @@ import torch
from
tests.kernels.allclose_default
import
get_default_atol
,
get_default_rtol
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.layer
import
Attention
,
MultiHeadAttention
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_max_shared_memory_bytes
...
...
@@ -449,7 +450,8 @@ def test_multi_query_kv_attention(
start
+=
seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias
=
[
b
.
materialize
(
b
.
shape
,
device
=
device
).
squeeze
()
for
b
in
attn_bias
b
.
materialize
((
1
,
num_query_heads
,
i
,
i
),
device
=
device
).
squeeze
()
for
b
,
i
in
zip
(
attn_bias
,
seq_lens
)
]
else
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
...
...
@@ -506,3 +508,18 @@ def test_multi_query_kv_attention_with_alibi(
device
,
use_alibi
=
True
,
)
@
pytest
.
mark
.
parametrize
(
"attention_cls"
,
[
Attention
,
MultiHeadAttention
])
def
test_num_heads_not_divisble_by_num_kv_heads
(
attention_cls
:
type
)
->
None
:
head_size
=
64
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
num_heads
=
16
num_kv_heads
=
5
with
pytest
.
raises
(
AssertionError
):
_
=
attention_cls
(
num_heads
=
num_heads
,
head_size
=
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
,
)
tests/kernels/attention/test_attention_selector.py
View file @
99324e25
...
...
@@ -106,10 +106,8 @@ def test_env(
block_size
,
False
,
use_mla
=
use_mla
)
if
use_v1
and
name
!=
"TRITON_MLA"
:
assert
backend
.
get_name
()
==
f
"
{
name
}
_VLLM_V1"
else
:
assert
backend
.
get_name
()
==
name
expected
=
f
"
{
name
}
_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
else
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
...
...
@@ -173,7 +171,7 @@ def test_env(
expected
=
"FLASHINFER_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
torch
.
float16
,
block_size
,
...
...
@@ -182,6 +180,45 @@ def test_env(
expected
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
if
use_v1
:
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
block_size
,
False
,
use_mla
=
use_mla
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
,
(
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
def
test_fp32_fallback
(
device
:
str
,
use_v1
:
bool
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
"""Test attention backend selection with fp32."""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
16
,
False
)
assert
(
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
if
use_v1
else
"TORCH_SDPA"
)
elif
device
==
"cuda"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
16
,
False
)
assert
(
backend
.
get_name
()
==
"FLEX_ATTENTION"
if
use_v1
else
"XFORMERS"
)
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
...
...
tests/kernels/attention/test_cache.py
View file @
99324e25
...
...
@@ -72,8 +72,8 @@ def test_copy_blocks(
# destination blocks.
assert
2
*
num_mappings
<=
num_blocks
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
remaini
n
g_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remaini
n
g_blocks
,
2
*
num_mappings
)
block_mapping
:
list
[
tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
...
...
@@ -189,12 +189,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
...
...
@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
kv_dtype
=
kv_cache_dtype
)
# Run the reference implementation.
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
if
kv_cache_layout
==
"NHD"
:
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
...
...
tests/kernels/attention/test_encoder_decoder_attn.py
View file @
99324e25
...
...
@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
# Narrow te
e
st-cases for unsupported-scenario
# Narrow test-cases for unsupported-scenario
# tests
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
...
...
@@ -99,7 +99,7 @@ class TestResources(NamedTuple):
Attributes:
* scale: 1/sqrt(d) scale factor for attn
* attn_backend: implementati
n
o of abstraction
* attn_backend: implementatio
ns
of abstraction
attention interface using
a particular kernel library
i.e. XFormers
...
...
tests/kernels/attention/test_mla_decode_cpu.py
View file @
99324e25
...
...
@@ -7,10 +7,7 @@ from torch import Tensor
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
from
vllm.utils
import
cdiv
def
ref_mla
(
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
99324e25
...
...
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# If attention backend is None
# If use_mla is true
...
...
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# change the attention backend to AITER MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
...
...
tests/kernels/attention/test_triton_decode_attention.py
View file @
99324e25
...
...
@@ -5,10 +5,7 @@ import pytest
import
torch
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
from
vllm.utils
import
cdiv
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
99324e25
...
...
@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_conti
n
gous"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_contig
u
ous"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
,
use_key
,
head_stride_is_conti
n
gous
):
seq_len
,
use_key
,
head_stride_is_contig
u
ous
):
batch_size
=
1
base
=
10000
num_heads
=
7
...
...
@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
head_stride
=
head_size
+
(
64
if
head_stride_is_conti
n
gous
else
0
)
head_stride
=
head_size
+
(
64
if
head_stride_is_contig
u
ous
else
0
)
query
=
torch
.
randn
(
batch_size
,
seq_len
,
...
...
@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if
head_stride_is_conti
n
gous
:
if
head_stride_is_contig
u
ous
:
rotary_embedding_opcheck
(
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
99324e25
...
...
@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
return
A
,
dt
,
X
,
B
,
C
def
generate_continous_batched_examples
(
example_lens_by_batch
,
num_examples
,
full_length
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
,
device
=
'cuda'
):
def
generate_contin
u
ous_batched_examples
(
example_lens_by_batch
,
num_examples
,
full_length
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
,
device
=
'cuda'
):
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
...
...
@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continous_batched_examples
(
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continuous_batched_examples
(
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
\
_query_start_loc_to_chunk_indices_offsets
(
...
...
tests/kernels/moe/
deepep
_utils.py
→
tests/kernels/moe/
parallel
_utils.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepEP test utilities
"""
import
dataclasses
import
importlib
import
os
import
traceback
from
typing
import
Callable
,
Optional
...
...
@@ -13,6 +15,8 @@ from torch.multiprocessing import (
spawn
)
# pyright: ignore[reportPrivateImportUsage]
from
typing_extensions
import
Concatenate
,
ParamSpec
from
vllm.utils
import
get_open_port
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
if
has_deep_ep
:
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
...
...
@@ -92,7 +96,7 @@ def parallel_launch(
world_size
,
world_size
,
0
,
"tcp://
localhost:29500
"
,
f
"tcp://
{
os
.
getenv
(
'LOCALHOST'
,
'localhost'
)
}
:
{
get_open_port
()
}
"
,
worker
,
)
+
args
,
nprocs
=
world_size
,
...
...
@@ -134,18 +138,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
low_latency_mode
=
low_latency_mode
,
num_qps_per_rank
=
num_qps_per_rank
)
return
DeepEPHTPrepareAndFinalize
(
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
rank
=
pgi
.
rank
,
num_dispatchers
=
pgi
.
world_size
,
dp_size
=
dp_size
,
rank_expert_offset
=
pgi
.
rank
*
ht_args
.
num_local_experts
,
quant_dtype
=
q_dtype
,
block_shape
=
block_shape
)
ht_args
.
num_local_experts
)
def
make_deepep_ll_a2a
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
deepep_ll_args
:
DeepEPLLArgs
,
q_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
...
...
@@ -165,11 +165,8 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
return
DeepEPLLPrepareAndFinalize
(
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
num_dispatchers
=
pgi
.
world_size
,
max_tokens_per_rank
=
deepep_ll_args
.
max_tokens_per_rank
,
quant_dtype
=
q_dtype
,
block_shape
=
block_shape
,
use_fp8_dispatch
=
deepep_ll_args
.
use_fp8_dispatch
,
)
...
...
@@ -187,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape
)
assert
deepep_ll_args
is
not
None
return
make_deepep_ll_a2a
(
pg
,
pgi
,
dp_size
,
deepep_ll_args
,
q_dtype
,
block_shape
)
return
make_deepep_ll_a2a
(
pg
,
pgi
,
deepep_ll_args
,
q_dtype
,
block_shape
)
tests/kernels/moe/test_batched_moe.py
View file @
99324e25
...
...
@@ -2,18 +2,57 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
torch
import
triton.language
as
tl
from
tests.kernels.moe.utils
import
(
batched_moe
,
make_quantized_test_activations
,
make_test_weights
,
naive_batched_moe
)
from
tests.kernels.quant_utils
import
native_batched_masked_quant_matmul
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.platforms
import
current_platform
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
128
,
2048
),
(
1
,
512
,
512
),
(
1
,
1024
,
128
),
(
1
,
1024
,
2048
),
(
32
,
128
,
128
),
(
32
,
512
,
512
),
(
32
,
1024
,
2048
),
(
45
,
128
,
128
),
(
45
,
128
,
2048
),
(
45
,
512
,
512
),
(
45
,
1024
,
128
),
(
45
,
1024
,
2048
),
(
64
,
512
,
512
),
(
64
,
1024
,
2048
),
(
222
,
128
,
128
),
(
222
,
128
,
2048
),
(
222
,
1024
,
128
),
(
222
,
1024
,
2048
),
]
NUM_EXPERTS
=
[
8
,
64
]
TOP_KS
=
[
1
,
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
dataclass
class
BatchedMMConfig
:
dtype
:
torch
.
dtype
in_dtype
:
torch
.
dtype
quant_dtype
:
Optional
[
torch
.
dtype
]
out_dtype
:
torch
.
dtype
num_experts
:
int
max_tokens_per_expert
:
int
K
:
int
...
...
@@ -32,79 +71,129 @@ class BatchedMMTensors:
A
=
torch
.
randn
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
/
10
dtype
=
config
.
in_
dtype
)
/
10
B
=
torch
.
randn
((
config
.
num_experts
,
config
.
N
,
config
.
K
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
dtype
=
config
.
in_
dtype
)
C
=
torch
.
zeros
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
N
),
device
=
"cuda"
,
dtype
=
config
.
dtype
)
dtype
=
config
.
out_dtype
)
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
high
=
config
.
max_tokens_per_expert
,
size
=
(
config
.
num_experts
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
return
BatchedMMTensors
(
A
,
B
,
C
,
num_expert_tokens
)
def
ref_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
num_expert_tokens
:
torch
.
Tensor
)
->
torch
.
Tensor
:
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
,
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
,
block_shape
:
Optional
[
list
[
int
]],
per_act_token_quant
:
bool
):
current_platform
.
seed_everything
(
7
)
num_expert_tokens_cpu
=
num_expert_tokens
.
clone
()
num_expert_tokens_cpu
=
num_expert_tokens_cpu
.
to
(
device
=
"cpu"
)
num_experts
=
num_expert_tokens
.
size
(
0
)
use_fp8_w8a8
=
dtype
==
torch
.
float8_e4m3fn
for
e
in
range
(
num_experts
):
num_tokens
=
num_expert_tokens_cpu
[
e
]
C
[
e
,
:
num_tokens
,
:]
=
A
[
e
,
:
num_tokens
,
:]
@
B
[
e
].
transpose
(
0
,
1
)
if
(
per_act_token_quant
or
block_shape
is
not
None
)
and
not
use_fp8_w8a8
:
pytest
.
skip
(
"Don't test blocking for non-quantized types."
)
return
C
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization test."
)
if
dtype
.
itemsize
==
1
:
act_dtype
=
torch
.
bfloat16
quant_dtype
=
dtype
else
:
act_dtype
=
dtype
quant_dtype
=
None
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
):
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
high
=
max_tokens_per_expert
,
size
=
(
num_experts
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
config
=
BatchedMMConfig
(
dtype
,
num_experts
,
max_tokens_per_expert
,
K
,
N
)
tensors
=
BatchedMMTensors
.
make_tensors
(
config
)
A
,
A_q
,
A_scale
=
make_quantized_test_activations
(
num_experts
,
max_tokens_per_expert
,
K
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
test_output
=
tensors
.
C
ref_output
=
test_output
.
clone
()
B
,
B_q
,
B_scale
,
_
,
_
,
_
=
make_test_weights
(
num_experts
,
N
//
2
,
K
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
out_shape
=
(
num_experts
,
max_tokens_per_expert
,
N
)
test_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
ref_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
q_ref_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
compute_tl_dtype
=
{
torch
.
float16
:
tl
.
float16
,
torch
.
bfloat16
:
tl
.
bfloat16
,
torch
.
float32
:
tl
.
float32
}[
test_output
.
dtype
]
assert
A_q
.
dtype
==
B_q
.
dtype
invoke_moe_batched_triton_kernel
(
tensors
.
A
,
tensors
.
B
,
A_q
,
B_q
,
test_output
,
tensors
.
num_expert_tokens
,
num_expert_tokens
,
compute_tl_dtype
,
# Quantization data
Non
e
,
Non
e
,
A_scal
e
,
B_scal
e
,
None
,
# Quantization schemes
False
,
use_fp8_w8a8
,
False
,
False
,
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
16
})
"BLOCK_SIZE_K"
:
16
if
dtype
.
itemsize
>
1
else
32
},
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
ref_output
=
ref_impl
(
tensors
.
A
,
tensors
.
B
,
ref_output
,
tensors
.
num_expert_tokens
)
ref_output
=
native_batched_masked_quant_matmul
(
A
,
B
,
ref_output
,
num_expert_tokens
,
)
q_ref_output
=
native_batched_masked_quant_matmul
(
A_q
,
B_q
,
q_ref_output
,
num_expert_tokens
,
A_scale
,
B_scale
,
block_shape
,
per_act_token_quant
)
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
...
...
@@ -112,4 +201,122 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
test_output
.
dtype
]
torch
.
testing
.
assert_close
(
test_output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
ref_output
,
q_ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
test_output
,
q_ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
((
"m"
,
"n"
,
"k"
),
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"input_scales"
,
[
False
])
def
test_fused_moe_batched_experts
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]],
input_scales
:
bool
,
):
current_platform
.
seed_everything
(
7
)
use_fp8_w8a8
=
dtype
==
torch
.
float8_e4m3fn
if
topk
>
e
:
pytest
.
skip
(
"topk > e"
)
if
not
use_fp8_w8a8
and
(
per_act_token_quant
or
block_shape
is
not
None
):
pytest
.
skip
(
"Skip quantization test for non-quantized type"
)
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization test."
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
if
dtype
.
itemsize
==
1
:
act_dtype
=
torch
.
bfloat16
quant_dtype
=
dtype
else
:
act_dtype
=
dtype
quant_dtype
=
None
w1_16
,
w1
,
w1_s
,
w2_16
,
w2
,
w2_s
=
make_test_weights
(
e
,
n
,
k
,
block_shape
=
block_shape
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
)
if
input_scales
and
quant_dtype
is
not
None
:
a1_scale
=
torch
.
tensor
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
tensor
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
else
:
a1_scale
=
None
a2_scale
=
None
with
set_current_vllm_config
(
vllm_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
baseline_output
=
torch_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
batched_output
=
naive_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
triton_output
=
batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
quant_dtype
=
quant_dtype
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
torch
.
testing
.
assert_close
(
batched_output
,
baseline_output
,
atol
=
3e-2
,
rtol
=
2e-2
)
torch
.
testing
.
assert_close
(
triton_output
,
batched_output
,
atol
=
2e-2
,
rtol
=
2e-2
)
tests/kernels/moe/test_block_fp8.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
tests.kernels.moe.utils
import
make_test_weights
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_fp8
,
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
)
from
vllm.platforms
import
current_platform
dg_available
=
False
try
:
import
deep_gemm
dg_available
=
True
except
ImportError
:
pass
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# Test configurations
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
512
,
512
),
(
1
,
128
,
7168
),
(
1
,
1024
,
7168
),
(
1
,
4608
,
128
),
(
1
,
4608
,
512
),
(
1
,
4608
,
7168
),
(
83
,
128
,
128
),
(
83
,
512
,
512
),
(
83
,
1024
,
7168
),
(
83
,
4608
,
512
),
(
83
,
4608
,
7168
),
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
1024
,
7168
),
(
128
,
4608
,
512
),
(
128
,
4608
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
1024
,
7168
),
(
2048
,
4608
,
512
),
(
2048
,
4608
,
7168
),
(
8192
,
128
,
128
),
(
8192
,
512
,
512
),
(
8192
,
128
,
7168
),
(
8192
,
1024
,
7168
),
(
8192
,
4608
,
512
),
(
8192
,
4608
,
7168
),
]
MNK_FACTORS_DG
=
[
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
128
,
7168
),
(
128
,
1024
,
7168
),
(
128
,
4608
,
128
),
(
128
,
4608
,
512
),
(
128
,
4608
,
7168
),
(
192
,
128
,
128
),
(
192
,
512
,
512
),
(
192
,
1024
,
7168
),
(
192
,
4608
,
512
),
(
192
,
4608
,
7168
),
(
1335
,
128
,
128
),
(
1335
,
1024
,
7168
),
(
1335
,
4608
,
512
),
(
1335
,
4608
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
512
,
512
),
(
2048
,
128
,
7168
),
(
2048
,
1024
,
7168
),
(
2048
,
4608
,
128
),
(
2048
,
4608
,
512
),
(
2048
,
4608
,
7168
),
]
BLOCK_SIZE
=
[[
128
,
128
]]
E
=
[
2
,
8
,
16
]
# [128, 256]
TOP_KS
=
[
1
,
2
,
6
]
SEEDS
=
[
0
]
def
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weight
,
topk_ids
,
block_shape
):
"""Fused moe with block-wise quantization using native torch."""
B
,
D
=
a
.
shape
topk
=
topk_ids
.
size
(
1
)
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
D
)
out
=
torch
.
zeros
(
B
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
topk_weight
=
topk_weight
.
view
(
-
1
)
topk_ids
=
topk_ids
.
view
(
-
1
)
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
a_q
,
a_s
=
native_per_token_group_quant_fp8
(
a
,
block_k
)
a_q
=
a_q
.
to
(
torch
.
float32
)
for
i
in
range
(
w1
.
shape
[
0
]):
mask
=
topk_ids
==
i
if
mask
.
sum
():
inter_out
=
native_w8a8_block_matmul
(
a_q
[
mask
],
w1
[
i
],
a_s
[
mask
],
w1_s
[
i
],
block_shape
,
output_dtype
=
a
.
dtype
)
act_out
=
SiluAndMul
().
forward_native
(
inter_out
)
act_out_q
,
act_out_s
=
native_per_token_group_quant_fp8
(
act_out
,
block_k
)
out
[
mask
]
=
native_w8a8_block_matmul
(
act_out_q
,
w2
[
i
],
act_out_s
,
w2_s
[
i
],
block_shape
,
output_dtype
=
a
.
dtype
)
return
(
out
.
view
(
B
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
# Skip all tests if CUDA is not available
pytest
.
importorskip
(
"torch.cuda"
)
@
pytest
.
fixture
(
autouse
=
True
)
def
setup_cuda
():
torch
.
set_default_device
(
"cuda"
)
@
pytest
.
mark
.
parametrize
((
"M"
,
"N"
,
"K"
),
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"E"
,
E
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
block_size
,
dtype
,
seed
,
monkeypatch
):
if
topk
>
E
:
pytest
.
skip
(
f
"Skipping test; topk=
{
topk
}
> E=
{
E
}
"
)
torch
.
manual_seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"2048"
)
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
E
,
N
,
K
,
dtype
,
torch
.
float8_e4m3fn
,
per_act_token_quant
=
False
,
block_shape
=
block_size
)
m_fused_moe
=
modular_triton_fused_moe
(
use_fp8_w8a8
=
True
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_act_token_quant
=
False
,
block_shape
=
block_size
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
,
block_size
,
)
out
=
fused_experts
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
block_shape
=
block_size
,
)
m_out
=
m_fused_moe
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol
=
0.035
if
M
<
40000
else
0.039
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
tol
,
rtol
=
tol
)
torch
.
testing
.
assert_close
(
m_out
,
ref_out
,
atol
=
tol
,
rtol
=
tol
)
@
pytest
.
mark
.
parametrize
((
"M"
,
"N"
,
"K"
),
MNK_FACTORS_DG
)
@
pytest
.
mark
.
parametrize
(
"E"
,
E
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
,
monkeypatch
):
if
topk
>
E
:
pytest
.
skip
(
f
"Skipping test: topk=
{
topk
}
> E=
{
E
}
"
)
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
pytest
.
skip
(
f
"Skipping test: invalid size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
"
)
chunk_size
=
1024
torch
.
manual_seed
(
seed
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
str
(
chunk_size
))
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
E
,
N
,
K
,
dtype
,
torch
.
float8_e4m3fn
,
per_act_token_quant
=
False
,
block_shape
=
block_size
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile
=
False
use_cudagraph
=
(
chunk_size
<
M
and
N
>=
1024
and
K
>=
1024
and
current_platform
.
is_cuda_alike
())
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
,
block_size
)
if
use_compile
:
deep_gemm_moe_fp8_fn
=
torch
.
compile
(
deep_gemm_moe_fp8
,
backend
=
"inductor"
,
fullgraph
=
True
)
torch
.
_dynamo
.
mark_dynamic
(
a
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
topk_weights
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
topk_ids
,
0
)
else
:
deep_gemm_moe_fp8_fn
=
deep_gemm_moe_fp8
out
=
deep_gemm_moe_fp8_fn
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
if
use_cudagraph
:
out
.
fill_
(
0
)
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
out
=
deep_gemm_moe_fp8_fn
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.035
,
rtol
=
0.035
)
tests/kernels/moe/test_block_int8.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
tests.kernels.moe.utils
import
make_test_weights
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_int8
,
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.platforms
import
current_platform
if
current_platform
.
get_device_capability
()
<
(
7
,
0
):
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
]
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
512
,
512
),
(
1
,
128
,
7168
),
(
1
,
1024
,
7168
),
(
1
,
4096
,
128
),
(
1
,
4096
,
512
),
(
1
,
4096
,
7168
),
(
33
,
128
,
128
),
(
33
,
512
,
512
),
(
33
,
128
,
7168
),
(
33
,
1024
,
7168
),
(
33
,
4096
,
128
),
(
33
,
4096
,
512
),
(
33
,
4096
,
7168
),
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
1024
,
7168
),
(
128
,
4096
,
512
),
(
128
,
4096
,
7168
),
(
222
,
128
,
128
),
(
222
,
512
,
512
),
(
222
,
1024
,
7168
),
(
222
,
4096
,
512
),
(
222
,
4096
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
1024
,
7168
),
(
2048
,
4096
,
512
),
(
2048
,
4096
,
7168
),
]
E
=
[
8
,
24
]
TOP_KS
=
[
2
,
6
]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE
=
[[
128
,
128
]]
SEEDS
=
[
0
]
# For test
def
torch_w8a8_block_int8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_shape
):
"""This function performs fused moe with block-wise quantization using
native torch."""
B
,
D
=
a
.
shape
a
=
a
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
D
)
out
=
torch
.
zeros
(
B
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
score
=
torch
.
softmax
(
score
,
dim
=-
1
,
dtype
=
torch
.
float32
)
topk_weight
,
topk_ids
=
torch
.
topk
(
score
,
topk
)
topk_weight
=
topk_weight
.
view
(
-
1
)
topk_ids
=
topk_ids
.
view
(
-
1
)
_
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
a_q
,
a_s
=
native_per_token_group_quant_int8
(
a
,
block_k
)
for
i
in
range
(
w1
.
shape
[
0
]):
mask
=
topk_ids
==
i
if
mask
.
sum
():
inter_out
=
native_w8a8_block_matmul
(
a_q
[
mask
],
w1
[
i
],
a_s
[
mask
],
w1_s
[
i
],
block_shape
,
output_dtype
=
a
.
dtype
)
act_out
=
SiluAndMul
().
forward_native
(
inter_out
)
act_out_q
,
act_out_s
=
native_per_token_group_quant_int8
(
act_out
,
block_k
)
act_out
=
act_out
.
to
(
torch
.
float32
)
out
[
mask
]
=
native_w8a8_block_matmul
(
act_out_q
,
w2
[
i
],
act_out_s
,
w2_s
[
i
],
block_shape
,
output_dtype
=
a
.
dtype
)
return
(
out
.
view
(
B
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
B
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"module"
)
def
setup_cuda
():
"""Sets the default CUDA device for all tests in this module."""
torch
.
set_default_device
(
"cuda"
)
@
pytest
.
mark
.
parametrize
((
"M"
,
"N"
,
"K"
),
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"E"
,
E
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
torch
.
inference_mode
()
def
test_w8a8_block_int8_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
block_size
,
dtype
,
seed
):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch
.
manual_seed
(
seed
)
a
=
torch
.
randn
((
M
,
K
),
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
_
,
w1
,
w1_s
,
_
,
w2
,
w2_s
=
make_test_weights
(
E
,
N
,
K
,
dtype
,
torch
.
int8
,
per_act_token_quant
=
False
,
block_shape
=
block_size
)
# Set the context to avoid lots of warning spam.
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
,
use_int8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
block_shape
=
block_size
,
)
ref_out
=
torch_w8a8_block_int8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
# Check results
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.065
,
rtol
=
0.065
)
tests/kernels/moe/test_cutlass_grouped_gemm.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
import
random
import
pytest
import
torch
from
tests.kernels.utils
import
baseline_scaled_mm
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
pad_size
=
(
128
-
(
n
%
128
))
%
128
x
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
pad_size
),
value
=
0
)
if
pad_size
>
0
else
x
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
fp8_data
=
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
fp8_data
.
view
(
m
,
n
+
pad_size
)[:,
:
n
],
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
((
cdiv
(
m
,
128
)
*
128
,
cdiv
(
n
,
128
)
*
128
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
@
pytest
.
mark
.
parametrize
(
"num_groups, expected_m_per_group, k, n"
,
[
(
4
,
8192
,
7168
,
4096
),
(
4
,
8192
,
2048
,
7168
),
(
8
,
4096
,
7168
,
4096
),
(
8
,
4096
,
2048
,
7168
),
(
32
,
1024
,
7168
,
4096
),
(
32
,
1024
,
2048
,
7168
),
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
float16
])
@
pytest
.
mark
.
skipif
(
(
lambda
x
:
x
is
None
or
x
.
to_int
()
!=
100
)(
current_platform
.
get_device_capability
()),
reason
=
"Block Scaled Grouped GEMM is only supported on SM100."
)
def
test_cutlass_grouped_gemm
(
num_groups
:
int
,
expected_m_per_group
:
int
,
k
:
int
,
n
:
int
,
out_dtype
:
torch
.
dtype
,
):
device
=
"cuda"
alignment
=
128
group_ms
=
[
int
(
expected_m_per_group
*
random
.
uniform
(
0.7
,
1.3
))
for
_
in
range
(
num_groups
)
]
m
=
sum
([
cdiv
(
m
,
alignment
)
*
alignment
for
m
in
group_ms
])
x
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
out_dtype
)
y
=
torch
.
randn
((
num_groups
,
n
,
k
),
device
=
device
,
dtype
=
out_dtype
)
out
=
torch
.
empty
((
m
,
n
),
device
=
device
,
dtype
=
out_dtype
)
ref_out
=
torch
.
randn
((
m
,
n
),
device
=
device
,
dtype
=
out_dtype
)
ep_offset
=
[
0
]
+
[
sum
(
group_ms
[:
i
])
for
i
in
range
(
1
,
num_groups
)]
+
[
m
]
pb_size
=
[]
for
i
in
range
(
num_groups
):
pb_size
.
append
([
ep_offset
[
i
+
1
]
-
ep_offset
[
i
],
n
,
k
])
problem_sizes
=
torch
.
tensor
(
pb_size
,
device
=
device
,
dtype
=
torch
.
int32
)
expert_offsets
=
torch
.
tensor
(
ep_offset
,
device
=
device
,
dtype
=
torch
.
int32
)
x_fp8
=
per_token_cast_to_fp8
(
x
)
y_fp8
=
(
torch
.
empty_like
(
y
,
dtype
=
torch
.
float8_e4m3fn
),
torch
.
empty
((
num_groups
,
cdiv
(
n
,
128
),
k
//
128
),
device
=
device
,
dtype
=
torch
.
float
))
for
i
in
range
(
num_groups
):
y_fp8
[
0
][
i
],
y_fp8
[
1
][
i
]
=
per_block_cast_to_fp8
(
y
[
i
])
for
i
in
range
(
num_groups
):
a
=
x_fp8
[
0
][
ep_offset
[
i
]:
ep_offset
[
i
+
1
]]
a_scale
=
x_fp8
[
1
][
ep_offset
[
i
]:
ep_offset
[
i
+
1
]]
b
=
y_fp8
[
0
][
i
].
t
()
b_scale
=
y_fp8
[
1
][
i
].
t
()
baseline
=
baseline_scaled_mm
(
a
,
b
,
a_scale
,
b_scale
,
out_dtype
)
ref_out
[
ep_offset
[
i
]:
ep_offset
[
i
+
1
]]
=
baseline
ops
.
cutlass_blockwise_scaled_grouped_mm
(
out
,
x_fp8
[
0
],
y_fp8
[
0
],
x_fp8
[
1
],
y_fp8
[
1
],
problem_sizes
,
expert_offsets
[:
-
1
],
)
torch
.
testing
.
assert_close
(
ref_out
,
out
,
atol
=
5e-1
,
rtol
=
1e-3
)
tests/kernels/moe/test_cutlass_moe.py
View file @
99324e25
...
...
@@ -29,6 +29,10 @@ MNK_FACTORS = [
(
224
,
1024
,
1536
),
(
224
,
3072
,
1024
),
(
224
,
3072
,
1536
),
(
32768
,
1024
,
1024
),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
...
...
@@ -93,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales
=
2
*
n
if
per_out_channel
else
1
k_b_scales
=
k
if
per_out_channel
else
1
# Get the right scale for tests.
_
,
a_scale
=
ops
.
scaled_fp8_quant
(
moe_tensors_fp16
.
a
,
use_per_token_if_dynamic
=
per_act_token
)
a_q
,
_
=
ops
.
scaled_fp8_quant
(
moe_tensors_fp16
.
a
,
a_scale
,
use_per_token_if_dynamic
=
per_act_token
)
a_q
,
a_scale
=
ops
.
scaled_fp8_quant
(
moe_tensors_fp16
.
a
,
None
,
use_per_token_if_dynamic
=
per_act_token
)
w1_q
=
torch
.
empty
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
q_dtype
)
w2_q
=
torch
.
empty
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
q_dtype
)
...
...
@@ -183,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def
run_8_bit
(
moe_tensors
:
MOETensors8Bit
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
num_local_experts
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
assert
not
any
([
t
is
None
for
t
in
[
...
...
@@ -199,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'a1_scale'
:
moe_tensors
.
a_scale
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
}
num_experts
=
moe_tensors
.
w1
.
size
(
0
)
...
...
@@ -231,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
topk
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
...
...
@@ -248,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output
=
fused_experts
(
mt
.
a_d
,
mt
.
w1_d
,
mt
.
w2_d
,
topk_weights
,
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
,
per_act_token
)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch
.
testing
.
assert_close
(
triton_output
,
cutlass_output
,
atol
=
5e-2
,
atol
=
5.
5e-2
,
rtol
=
1e-2
)
...
...
@@ -273,8 +281,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
topk
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
dtype
=
torch
.
half
...
...
@@ -295,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
,
per_act_token
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
...
...
@@ -328,8 +339,10 @@ def test_cutlass_moe_8_bit_EP(
per_act_token
:
bool
,
per_out_channel
:
bool
,
ep_size
:
int
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_channel
)
...
...
@@ -349,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
,
per_act_token
,
num_local_experts
=
e
//
ep_size
)
torch
.
testing
.
assert_close
(
triton_output
,
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test DeepEP + DeepGEMM integration
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
"""
import
dataclasses
import
importlib
from
typing
import
Optional
import
pytest
...
...
@@ -18,41 +18,34 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
.deepep_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
try
:
import
deep_gemm
has_deep_gemm
=
True
except
ImportError
:
has_deep_gemm
=
False
if
has_deep_ep
:
if
has_deep_ep
():
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
from
.deepep_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
from
.parallel_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
if
has_deep_gemm
():
if
has_deep_gemm
:
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
)
requires_deep_ep
=
pytest
.
mark
.
skipif
(
not
has_deep_ep
,
not
has_deep_ep
()
,
reason
=
"Requires deep_ep kernels"
,
)
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
has_deep_gemm
,
not
has_deep_gemm
()
,
reason
=
"Requires deep_gemm kernels"
,
)
...
...
@@ -66,25 +59,6 @@ def next_power_of_2(x):
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
def
make_block_quant_fp8_weights
(
e
:
int
,
n
:
int
,
...
...
@@ -92,43 +66,11 @@ def make_block_quant_fp8_weights(
block_size
:
list
[
int
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Return weights
w1, w2,
w1q, w2q, w1_scale, w2_scale
Return weights w1q, w2q, w1_scale, w2_scale
"""
dtype
=
torch
.
bfloat16
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
w1_bf16
=
torch
.
randn
((
e
,
2
*
n
,
k
),
dtype
=
dtype
)
/
10
w1_bf16
=
w1_bf16
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
dtype
=
dtype
)
w2_bf16
=
torch
.
randn
((
e
,
k
,
n
),
dtype
=
dtype
)
/
10
w2_bf16
=
w2_bf16
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
dtype
=
dtype
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles_w1
=
((
2
*
n
)
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
k
+
block_k
-
1
)
//
block_k
n_tiles_w2
=
(
k
+
block_n
-
1
)
//
block_n
k_tiles_w2
=
(
n
+
block_k
-
1
)
//
block_k
w1
=
torch
.
empty_like
(
w1_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w2
=
torch
.
empty_like
(
w2_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w1_s
=
torch
.
empty
((
e
,
n_tiles_w1
,
k_tiles_w1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
((
e
,
n_tiles_w2
,
k_tiles_w2
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
assert
w1_s
.
shape
==
(
e
,
(
2
*
n
+
127
)
//
128
,
(
k
+
127
)
//
128
)
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
for
i
in
range
(
e
):
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
return
w1
,
w2
,
w1_s
,
w2_s
w1
,
w1q
,
w1_scale
,
w2
,
w2q
,
w2_scale
=
make_test_weights
(
e
,
n
,
k
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
,
block_size
)
return
w1q
,
w2q
,
w1_scale
,
w2_scale
@
dataclasses
.
dataclass
...
...
@@ -138,6 +80,7 @@ class TestConfig:
k
:
int
n
:
int
num_experts
:
int
per_act_token_quant
:
bool
block_size
:
list
[
int
]
# configs for testing low-latency kernels
low_latency
:
bool
...
...
@@ -156,8 +99,7 @@ class TestTensors:
def
make
(
config
:
TestConfig
,
rank
)
->
"TestTensors"
:
dtype
=
torch
.
bfloat16
topk
,
m
,
k
,
block_size
=
(
config
.
topk
,
config
.
m
,
config
.
k
,
config
.
block_size
)
topk
,
m
,
k
=
(
config
.
topk
,
config
.
m
,
config
.
k
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
...
...
@@ -165,9 +107,7 @@ class TestTensors:
rank_tokens
=
torch
.
randn
(
(
m
,
k
),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
/
10.0
rank_tokens
=
rank_tokens
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
block_k
=
block_size
[
1
]
_
,
rank_token_scales
=
per_token_group_quant_fp8
(
rank_tokens
,
block_k
)
rank_token_scales
=
None
topk_ids
=
torch
.
randint
(
low
=
0
,
...
...
@@ -207,10 +147,11 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype
=
q_dtype
,
block_shape
=
test_config
.
block_size
)
fused_experts
=
BatchedDeepGemmExperts
(
max_num_tokens
=
max_tokens_per_rank
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
block_shape
=
test_config
.
block_size
)
fused_experts
=
BatchedDeepGemmExperts
(
max_num_tokens
=
max_tokens_per_rank
,
num_dispatchers
=
pgi
.
world_size
//
dp_size
,
block_shape
=
test_config
.
block_size
,
per_act_token_quant
=
test_config
.
per_act_token_quant
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
return
mk
...
...
@@ -432,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
import
deep_gemm
m
,
n
,
k
=
mnk
current_platform
.
seed_everything
(
7
)
...
...
@@ -448,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k
=
k
,
n
=
n
,
num_experts
=
num_experts
,
per_act_token_quant
=
False
,
block_size
=
block_size
,
low_latency
=
False
,
use_fp8_dispatch
=
None
)
...
...
@@ -480,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
requires_deep_ep
@
requires_deep_gemm
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
use_fp8_dispatch
:
bool
,
block_size
:
list
[
int
],
world_dp_size
:
tuple
[
int
,
int
]):
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
use_fp8_dispatch
:
bool
,
block_size
:
list
[
int
],
world_dp_size
:
tuple
[
int
,
int
],
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
...
...
@@ -501,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k
=
k
,
n
=
n
,
num_experts
=
num_experts
,
per_act_token_quant
=
False
,
block_size
=
block_size
,
low_latency
=
True
,
use_fp8_dispatch
=
use_fp8_dispatch
,
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test deepep dispatch-combine logic
"""
import
dataclasses
import
importlib
from
typing
import
Optional
,
Union
import
pytest
...
...
@@ -22,21 +22,20 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
from
.
deepep
_utils
import
ProcessGroupInfo
,
parallel_launch
from
.
parallel
_utils
import
ProcessGroupInfo
,
parallel_launch
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
if
has_deep_ep
:
if
has_deep_ep
():
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
from
.
deepep
_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
from
.
parallel
_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
requires_deep_ep
=
pytest
.
mark
.
skipif
(
not
has_deep_ep
,
not
has_deep_ep
()
,
reason
=
"Requires deep_ep kernels"
,
)
...
...
@@ -104,10 +103,6 @@ class TestTensors:
rank_tokens
=
torch
.
randn
(
(
config
.
m
,
config
.
k
),
device
=
"cuda"
,
dtype
=
token_dtype
)
/
10
rank_token_scales
=
None
if
config
.
dtype
==
torch
.
float8_e4m3fn
:
# low_latency_mode kernels dont support per-token quant.
_
,
rank_token_scales
=
ops
.
scaled_fp8_quant
(
rank_tokens
,
use_per_token_if_dynamic
=
not
low_latency_mode
)
topk
=
torch
.
randint
(
low
=
0
,
high
=
config
.
num_experts
,
...
...
@@ -123,11 +118,18 @@ class TestTensors:
config
=
config
)
def
make_modular_kernel
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
low_latency_mode
:
bool
,
hidden_size
:
int
,
dp_size
:
int
,
num_experts
:
int
,
num_local_experts
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
use_fp8_dispatch
:
bool
)
->
FusedMoEModularKernel
:
def
make_modular_kernel
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
low_latency_mode
:
bool
,
hidden_size
:
int
,
dp_size
:
int
,
num_experts
:
int
,
num_local_experts
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
)
->
FusedMoEModularKernel
:
is_quantized
=
q_dtype
is
not
None
...
...
@@ -153,33 +155,47 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ht_args
=
ht_args
,
deepep_ll_args
=
ll_args
)
num_dispatchers
=
pgi
.
world_size
//
dp_size
if
low_latency_mode
:
assert
not
per_act_token_quant
,
"not supported in ll mode"
fused_experts
=
BatchedTritonExperts
(
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
world_size
=
pgi
.
world_size
,
dp_size
=
dp_size
,
num_dispatchers
=
num_dispatchers
,
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
)
use_int4_w4a16
=
False
,
per_act_token_quant
=
False
,
)
else
:
fused_experts
=
TritonExperts
(
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
False
)
fused_experts
=
TritonExperts
(
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_act_token_quant
=
per_act_token_quant
,
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
return
mk
def
deep_ep_moe_impl
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
low_latency_mode
:
bool
,
dp_size
:
int
,
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
num_experts
:
int
,
use_fp8_dispatch
:
bool
)
->
torch
.
Tensor
:
def
deep_ep_moe_impl
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
low_latency_mode
:
bool
,
dp_size
:
int
,
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
num_experts
:
int
,
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
)
->
torch
.
Tensor
:
num_local_experts
=
w1
.
size
(
0
)
...
...
@@ -201,11 +217,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype
=
torch
.
float8_e4m3fn
# Make modular kernel
mk
:
FusedMoEModularKernel
=
make_modular_kernel
(
pg
,
pgi
,
low_latency_mode
,
hidden_size
,
dp_size
,
num_experts
,
num_local_experts
,
q_dtype
,
use_fp8_dispatch
)
mk
:
FusedMoEModularKernel
=
make_modular_kernel
(
pg
,
pgi
,
low_latency_mode
,
hidden_size
,
dp_size
,
num_experts
,
num_local_experts
,
q_dtype
,
use_fp8_dispatch
,
per_act_token_quant
)
out_hidden_states
=
torch
.
empty_like
(
test_tensors
.
rank_tokens
)
total_num_tokens
=
test_tensors
.
rank_tokens
.
size
(
0
)
...
...
@@ -259,9 +273,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return
out_hidden_states
def
torch_moe_impl
(
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
using_fp8_dispatch
:
bool
):
def
torch_moe_impl
(
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
using_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
):
a
,
topk_ids
,
topk_weights
=
(
test_tensors
.
rank_tokens
,
test_tensors
.
topk
,
test_tensors
.
topk_weights
)
...
...
@@ -269,6 +289,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant.
assert
not
per_act_token_quant
a
=
test_tensors
.
rank_tokens
aq
,
aq_scale
=
per_token_group_quant_fp8
(
a
,
128
)
a
=
(
aq
.
view
(
-
1
,
128
).
to
(
torch
.
float32
)
*
aq_scale
.
view
(
-
1
,
1
)).
view
(
...
...
@@ -312,6 +333,7 @@ def _deep_ep_moe(
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
):
if
not
low_latency_mode
:
...
...
@@ -333,7 +355,8 @@ def _deep_ep_moe(
with
set_current_vllm_config
(
VllmConfig
()):
# Reference
torch_combined
=
torch_moe_impl
(
test_tensors
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
w2_scale
,
use_fp8_dispatch
,
per_act_token_quant
)
# Splice experts for this rank.
num_local_experts
=
config
.
num_experts
//
pgi
.
world_size
...
...
@@ -358,6 +381,7 @@ def _deep_ep_moe(
w2_scale_ep
,
config
.
num_experts
,
use_fp8_dispatch
,
per_act_token_quant
,
)
torch
.
testing
.
assert_close
(
...
...
@@ -386,10 +410,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
requires_deep_ep
def
test_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
def
test_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
],
per_act_token_quant
:
bool
,
):
low_latency_mode
=
False
use_fp8_dispatch
=
False
m
,
n
,
k
=
mnk
...
...
@@ -406,7 +436,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
,
per_act_token_quant
)
MNKs
=
[
...
...
@@ -456,4 +487,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
,
False
)
tests/kernels/moe/test_deepgemm.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""
import
importlib
import
math
import
pytest
import
torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
cdiv
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
if
has_deep_gemm
:
import
deep_gemm
BLOCK_M
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
BLOCK_SIZE
=
[
BLOCK_M
,
BLOCK_M
]
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
has_deep_gemm
,
reason
=
"Requires deep_gemm kernels"
,
)
def
calc_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
):
x
,
y
=
x
.
double
(),
y
.
double
()
denominator
=
(
x
*
x
+
y
*
y
).
sum
()
sim
=
2
*
(
x
*
y
).
sum
()
/
denominator
return
1
-
sim
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
cdiv
(
m
,
128
)
*
128
,
cdiv
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
def
make_block_quant_fp8_weights
(
e
:
int
,
n
:
int
,
k
:
int
,
block_size
:
list
[
int
],
):
"""
Generate (w1, w2) expert weights and their per-block scale tensors
in FP8 block-quantized format.
w1 shape: (E, 2N, K)
w2 shape: (E, K, N)
"""
dtype
=
torch
.
bfloat16
fp8_max
,
fp8_min
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
,
torch
.
finfo
(
torch
.
float8_e4m3fn
).
min
# bf16 reference weights
w1_bf16
=
torch
.
randn
(
e
,
2
*
n
,
k
,
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2_bf16
=
torch
.
randn
(
e
,
k
,
n
,
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1_bf16
.
clamp_
(
fp8_min
,
fp8_max
)
w2_bf16
.
clamp_
(
fp8_min
,
fp8_max
)
block_n
,
block_k
=
block_size
n_tiles_w1
=
math
.
ceil
((
2
*
n
)
/
block_n
)
k_tiles_w1
=
math
.
ceil
(
k
/
block_k
)
n_tiles_w2
=
math
.
ceil
(
k
/
block_n
)
k_tiles_w2
=
math
.
ceil
(
n
/
block_k
)
w1
=
torch
.
empty_like
(
w1_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w2
=
torch
.
empty_like
(
w2_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w1_s
=
torch
.
empty
(
e
,
n_tiles_w1
,
k_tiles_w1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
(
e
,
n_tiles_w2
,
k_tiles_w2
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
for
i
in
range
(
e
):
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
return
w1
,
w2
,
w1_s
,
w2_s
def
run_single_case
(
m
,
n
,
k
,
topk
,
num_experts
,
block_size
):
"""
Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
Triton baseline within tolerance.
"""
tokens_bf16
=
torch
.
randn
(
m
,
k
,
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
).
clamp_min_
(
-
1
).
clamp_max_
(
1
)
_
,
a1_scale
=
per_token_group_quant_fp8
(
tokens_bf16
,
block_size
[
1
])
# expert weight tensors
w1
,
w2
,
w1_s
,
w2_s
=
make_block_quant_fp8_weights
(
num_experts
,
n
,
k
,
block_size
)
router_logits
=
torch
.
randn
(
m
,
num_experts
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
topk_weights
,
topk_ids
=
torch
.
topk
(
router_logits
,
k
=
topk
,
dim
=-
1
)
topk_weights
=
torch
.
nn
.
functional
.
softmax
(
topk_weights
,
dim
=-
1
)
# triton referrence
out_triton
=
fused_experts
(
hidden_states
=
tokens_bf16
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
block_shape
=
block_size
,
allow_deep_gemm
=
False
,
)
# DeepGemm
out_deepgemm
=
fused_experts
(
hidden_states
=
tokens_bf16
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
use_fp8_w8a8
=
True
,
w1_scale
=
w1_s
,
w2_scale
=
w2_s
,
a1_scale
=
a1_scale
,
block_shape
=
block_size
,
allow_deep_gemm
=
True
,
)
base
=
out_triton
.
abs
().
mean
()
atol
=
0.1
*
base
.
clamp
(
min
=
1e-2
)
# 10% of mean, but not lower than 1e-3
rtol
=
0.05
# ----- Compare -----
torch
.
testing
.
assert_close
(
out_deepgemm
.
to
(
torch
.
float32
),
out_triton
.
to
(
torch
.
float32
),
rtol
=
rtol
,
atol
=
float
(
atol
),
)
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
MNKs
=
[
(
1024
,
512
,
128
),
(
1024
,
512
,
512
),
(
2048
,
512
,
512
),
(
512
,
1024
,
1024
),
(
512
,
2048
,
2048
),
(
4096
,
4096
,
1024
),
]
TOPKS
=
[
2
,
6
]
NUM_EXPERTS
=
[
32
]
@
pytest
.
mark
.
parametrize
(
"mnk"
,
MNKs
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOPKS
)
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
NUM_EXPERTS
)
@
requires_deep_gemm
def
test_deepgemm_vs_triton
(
mnk
,
topk
,
num_experts
,
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_DEEP_GEMM"
,
"1"
)
_fused_moe_mod
=
importlib
.
import_module
(
"vllm.model_executor.layers.fused_moe.fused_moe"
)
call_counter
=
{
"cnt"
:
0
}
orig_fn
=
_fused_moe_mod
.
deep_gemm_moe_fp8
def
_spy_deep_gemm_moe_fp8
(
*
args
,
**
kwargs
):
call_counter
[
"cnt"
]
+=
1
return
orig_fn
(
*
args
,
**
kwargs
)
monkeypatch
.
setattr
(
_fused_moe_mod
,
"deep_gemm_moe_fp8"
,
_spy_deep_gemm_moe_fp8
)
m
,
n
,
k
=
mnk
if
topk
>
num_experts
:
pytest
.
skip
(
f
"topk=
{
topk
}
> num_experts=
{
num_experts
}
"
)
run_single_case
(
m
=
m
,
n
=
n
,
k
=
k
,
topk
=
topk
,
num_experts
=
num_experts
,
block_size
=
BLOCK_SIZE
,
)
# ensure that the DeepGEMM path was indeed taken.
assert
call_counter
[
"cnt"
]
==
1
,
\
f
"DeepGEMM path was not executed during the test. "
\
f
"Call counter:
{
call_counter
[
'cnt'
]
}
"
tests/kernels/moe/test_moe.py
View file @
99324e25
...
...
@@ -4,6 +4,9 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import
functools
from
typing
import
Callable
,
Optional
,
Union
import
pytest
import
torch
from
torch.nn
import
Parameter
...
...
@@ -14,8 +17,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
)
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
...
...
@@ -39,7 +45,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
def
run_moe_test
(
baseline
:
Union
[
Callable
,
torch
.
Tensor
],
moe_fn
:
Callable
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
padding
:
bool
=
False
,
use_compile
:
bool
=
False
,
use_cudagraph
:
bool
=
False
,
atol
:
float
=
2e-2
,
rtol
:
float
=
0
,
)
->
torch
.
Tensor
:
if
isinstance
(
baseline
,
torch
.
Tensor
):
baseline_output
=
baseline
else
:
baseline_output
=
baseline
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
if
use_compile
:
moe_fn
=
torch
.
compile
(
moe_fn
,
backend
=
"inductor"
,
fullgraph
=
True
)
torch
.
_dynamo
.
mark_dynamic
(
a
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
score
,
0
)
test_output
=
moe_fn
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
if
use_cudagraph
:
test_output
.
fill_
(
0
)
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
test_output
=
moe_fn
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
test_output
,
baseline_output
,
atol
=
atol
,
rtol
=
rtol
)
return
baseline_output
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
32768
,
40000
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
...
...
@@ -47,6 +122,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"chunk_size"
,
[
8192
])
def
test_fused_moe
(
m
:
int
,
n
:
int
,
...
...
@@ -56,7 +132,21 @@ def test_fused_moe(
ep_size
:
int
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
chunk_size
:
int
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
str
(
chunk_size
))
#
# Setup test data
#
#
# Setup test data
#
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
...
@@ -76,38 +166,70 @@ def test_fused_moe(
else
:
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
#
# Setup test functions
#
m_fused_moe_fn
=
modular_triton_fused_moe
(
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_act_token_quant
=
False
,
block_shape
=
None
)
def
m_fused_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
return
m_fused_moe_fn
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
fused_moe_fn
=
functools
.
partial
(
fused_moe
,
renormalize
=
False
)
#
# Run tests
#
runner
=
functools
.
partial
(
run_moe_test
,
a
=
a
,
w1
=
w1
,
w2
=
w2
,
score
=
score
,
topk
=
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
padding
=
padding
,
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile
=
False
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
use_cudagraph
=
(
n
>=
1024
and
k
>=
1024
and
current_platform
.
is_cuda_alike
())
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
with
set_current_vllm_config
(
vllm_config
):
baseline_output
=
runner
(
torch_moe
,
iterative_moe
)
runner
(
baseline_output
,
fused_moe_fn
,
use_compile
=
use_compile
,
use_cudagraph
=
use_cudagraph
)
runner
(
baseline_output
,
m_fused_moe
,
use_compile
=
use_compile
,
use_cudagraph
=
use_cudagraph
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
32
,
222
])
...
...
@@ -217,7 +339,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
expert_map
=
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
...
@@ -243,46 +370,59 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if
dtype
==
torch
.
float32
:
pytest
.
skip
(
"AITER ROCm test skip for float32"
)
monkeypatch
.
setenv
(
'RANK'
,
"0"
)
monkeypatch
.
setenv
(
'LOCAL_RANK'
,
"0"
)
monkeypatch
.
setenv
(
'WORLD_SIZE'
,
"1"
)
monkeypatch
.
setenv
(
'MASTER_ADDR'
,
'localhost'
)
monkeypatch
.
setenv
(
'MASTER_PORT'
,
'12345'
)
init_distributed_environment
()
# Instantiate our and huggingface's MoE blocks
config
=
MixtralConfig
()
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
vllm_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
tp_size
=
1
,
dp_size
=
1
,
).
cuda
()
# Load the weights
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
vllm_config
.
compilation_config
.
static_forward_context
=
dict
()
with
(
set_current_vllm_config
(
vllm_config
),
set_forward_context
(
None
,
vllm_config
)):
config
=
MixtralConfig
()
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
vllm_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
tp_size
=
1
,
dp_size
=
1
,
).
cuda
()
# Load the weights
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
(
(
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
mixtral_moe_tol
=
{
torch
.
float32
:
1e-3
,
...
...
@@ -525,7 +665,12 @@ def test_fused_marlin_moe(
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
expert_map
=
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
...
...
tests/kernels/moe/test_moe_align_block_size.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size_triton
)
@
pytest
.
mark
.
parametrize
(
"block_size,num_tokens,topk,num_experts"
,
list
(
itertools
.
product
(
[
32
,
64
,
128
,
256
],
# block_size
[
1
,
3
,
7
,
16
,
256
,
2256
,
4096
,
],
# num_tokens
[
1
,
4
,
16
,
64
],
# topk
[
64
,
160
,
256
,
257
,
260
,
264
],
# num_experts
)),
)
def
test_moe_align_block_size_compare_implementations
(
block_size
,
num_tokens
,
topk
,
num_experts
):
topk_ids
=
torch
.
stack
([
torch
.
randperm
(
num_experts
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)[:
topk
]
for
_
in
range
(
num_tokens
)
])
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids_cuda
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids_cuda
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
expert_ids_cuda
=
torch
.
zeros
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad_cuda
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids_triton
=
torch
.
empty_like
(
sorted_ids_cuda
)
sorted_ids_triton
.
fill_
(
topk_ids
.
numel
())
expert_ids_triton
=
torch
.
zeros_like
(
expert_ids_cuda
)
num_tokens_post_pad_triton
=
torch
.
empty_like
(
num_tokens_post_pad_cuda
)
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids_cuda
,
expert_ids_cuda
,
num_tokens_post_pad_cuda
,
)
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids_triton
,
expert_ids_triton
,
num_tokens_post_pad_triton
,
)
assert
torch
.
allclose
(
expert_ids_cuda
,
expert_ids_triton
),
(
f
"Expert IDs mismatch for block_size=
{
block_size
}
, "
f
"num_tokens=
{
num_tokens
}
, topk=
{
topk
}
\n
"
f
"CUDA expert_ids:
{
expert_ids_cuda
}
\n
"
f
"Triton expert_ids:
{
expert_ids_triton
}
"
)
assert
torch
.
allclose
(
num_tokens_post_pad_cuda
,
num_tokens_post_pad_triton
),
(
f
"Num tokens post pad mismatch for block_size=
{
block_size
}
, "
f
"num_tokens=
{
num_tokens
}
, topk=
{
topk
}
\n
"
f
"CUDA num_tokens_post_pad:
{
num_tokens_post_pad_cuda
}
\n
"
f
"Triton num_tokens_post_pad:
{
num_tokens_post_pad_triton
}
"
)
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
Prev
1
…
12
13
14
15
16
17
18
19
20
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment