Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1570 additions
and
304 deletions
+1570
-304
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+18
-1
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+42
-5
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+8
-8
tests/kernels/attention/test_encoder_decoder_attn.py
tests/kernels/attention/test_encoder_decoder_attn.py
+2
-2
tests/kernels/attention/test_mla_decode_cpu.py
tests/kernels/attention/test_mla_decode_cpu.py
+1
-4
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+4
-2
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+1
-4
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+4
-4
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+13
-14
tests/kernels/moe/parallel_utils.py
tests/kernels/moe/parallel_utils.py
+9
-13
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+244
-37
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+296
-0
tests/kernels/moe/test_block_int8.py
tests/kernels/moe/test_block_int8.py
+147
-0
tests/kernels/moe/test_cutlass_grouped_gemm.py
tests/kernels/moe/test_cutlass_grouped_gemm.py
+116
-0
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+23
-9
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+34
-86
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+76
-44
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+226
-0
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+216
-71
tests/kernels/moe/test_moe_align_block_size.py
tests/kernels/moe/test_moe_align_block_size.py
+90
-0
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
tests/kernels/attention/test_attention.py
View file @
99324e25
...
@@ -10,6 +10,7 @@ import torch
...
@@ -10,6 +10,7 @@ import torch
from
tests.kernels.allclose_default
import
get_default_atol
,
get_default_rtol
from
tests.kernels.allclose_default
import
get_default_atol
,
get_default_rtol
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.layer
import
Attention
,
MultiHeadAttention
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_max_shared_memory_bytes
from
vllm.utils
import
get_max_shared_memory_bytes
...
@@ -449,7 +450,8 @@ def test_multi_query_kv_attention(
...
@@ -449,7 +450,8 @@ def test_multi_query_kv_attention(
start
+=
seq_len
start
+=
seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias
=
[
alibi_bias
=
[
b
.
materialize
(
b
.
shape
,
device
=
device
).
squeeze
()
for
b
in
attn_bias
b
.
materialize
((
1
,
num_query_heads
,
i
,
i
),
device
=
device
).
squeeze
()
for
b
,
i
in
zip
(
attn_bias
,
seq_lens
)
]
]
else
:
else
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
...
@@ -506,3 +508,18 @@ def test_multi_query_kv_attention_with_alibi(
...
@@ -506,3 +508,18 @@ def test_multi_query_kv_attention_with_alibi(
device
,
device
,
use_alibi
=
True
,
use_alibi
=
True
,
)
)
@
pytest
.
mark
.
parametrize
(
"attention_cls"
,
[
Attention
,
MultiHeadAttention
])
def
test_num_heads_not_divisble_by_num_kv_heads
(
attention_cls
:
type
)
->
None
:
head_size
=
64
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
num_heads
=
16
num_kv_heads
=
5
with
pytest
.
raises
(
AssertionError
):
_
=
attention_cls
(
num_heads
=
num_heads
,
head_size
=
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
,
)
tests/kernels/attention/test_attention_selector.py
View file @
99324e25
...
@@ -106,10 +106,8 @@ def test_env(
...
@@ -106,10 +106,8 @@ def test_env(
block_size
,
block_size
,
False
,
False
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
if
use_v1
and
name
!=
"TRITON_MLA"
:
expected
=
f
"
{
name
}
_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
f
"
{
name
}
_VLLM_V1"
assert
backend
.
get_name
()
==
expected
else
:
assert
backend
.
get_name
()
==
name
else
:
else
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
16
,
get_attn_backend
(
16
,
...
@@ -173,7 +171,7 @@ def test_env(
...
@@ -173,7 +171,7 @@ def test_env(
expected
=
"FLASHINFER_VLLM_V1"
if
use_v1
else
name
expected
=
"FLASHINFER_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
else
:
else
:
backend
=
get_attn_backend
(
16
,
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
torch
.
float16
,
block_size
,
block_size
,
...
@@ -182,6 +180,45 @@ def test_env(
...
@@ -182,6 +180,45 @@ def test_env(
expected
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
name
expected
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
if
use_v1
:
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
block_size
,
False
,
use_mla
=
use_mla
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
,
(
"Should fallback to FlexAttention if head size is "
"not supported by FlashAttention"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
def
test_fp32_fallback
(
device
:
str
,
use_v1
:
bool
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
"""Test attention backend selection with fp32."""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
16
,
False
)
assert
(
backend
.
get_name
()
==
"TORCH_SDPA_VLLM_V1"
if
use_v1
else
"TORCH_SDPA"
)
elif
device
==
"cuda"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
torch
.
float32
,
16
,
False
)
assert
(
backend
.
get_name
()
==
"FLEX_ATTENTION"
if
use_v1
else
"XFORMERS"
)
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
"""Test FlashAttn validation."""
...
...
tests/kernels/attention/test_cache.py
View file @
99324e25
...
@@ -72,8 +72,8 @@ def test_copy_blocks(
...
@@ -72,8 +72,8 @@ def test_copy_blocks(
# destination blocks.
# destination blocks.
assert
2
*
num_mappings
<=
num_blocks
assert
2
*
num_mappings
<=
num_blocks
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
remaini
n
g_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
dst_blocks
=
random
.
sample
(
remaini
n
g_blocks
,
2
*
num_mappings
)
block_mapping
:
list
[
tuple
[
int
,
int
]]
=
[]
block_mapping
:
list
[
tuple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
src
=
src_blocks
[
i
]
...
@@ -189,12 +189,12 @@ def test_reshape_and_cache(
...
@@ -189,12 +189,12 @@ def test_reshape_and_cache(
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
...
@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
...
@@ -322,12 +322,12 @@ def test_reshape_and_cache_flash(
kv_dtype
=
kv_cache_dtype
)
kv_dtype
=
kv_cache_dtype
)
# Run the reference implementation.
# Run the reference implementation.
block_indic
i
es
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indic
i
es_lst
=
block_indic
i
es
.
cpu
().
tolist
()
block_indices_lst
=
block_indices
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
block_idx
=
block_indic
i
es_lst
[
i
]
block_idx
=
block_indices_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
if
kv_cache_layout
==
"NHD"
:
if
kv_cache_layout
==
"NHD"
:
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
...
...
tests/kernels/attention/test_encoder_decoder_attn.py
View file @
99324e25
...
@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
...
@@ -46,7 +46,7 @@ CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_DEC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
MAX_ENC_SEQ_LENS
=
[
128
]
# Narrow te
e
st-cases for unsupported-scenario
# Narrow test-cases for unsupported-scenario
# tests
# tests
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
HEAD_SIZES_FOR_UNSUPP
=
[
HEAD_SIZES
[
0
]]
...
@@ -99,7 +99,7 @@ class TestResources(NamedTuple):
...
@@ -99,7 +99,7 @@ class TestResources(NamedTuple):
Attributes:
Attributes:
* scale: 1/sqrt(d) scale factor for attn
* scale: 1/sqrt(d) scale factor for attn
* attn_backend: implementati
n
o of abstraction
* attn_backend: implementatio
ns
of abstraction
attention interface using
attention interface using
a particular kernel library
a particular kernel library
i.e. XFormers
i.e. XFormers
...
...
tests/kernels/attention/test_mla_decode_cpu.py
View file @
99324e25
...
@@ -7,10 +7,7 @@ from torch import Tensor
...
@@ -7,10 +7,7 @@ from torch import Tensor
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
ref_mla
(
def
ref_mla
(
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
99324e25
...
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# If attention backend is None
# If attention backend is None
# If use_mla is true
# If use_mla is true
...
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
...
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
None
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
False
,
True
)
False
,
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
(
backend
.
get_name
()
==
"TRITON_MLA"
or
backend
.
get_name
()
==
"TRITON_MLA_VLLM_V1"
)
# change the attention backend to AITER MLA
# change the attention backend to AITER MLA
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_AITER_MLA"
)
...
...
tests/kernels/attention/test_triton_decode_attention.py
View file @
99324e25
...
@@ -5,10 +5,7 @@ import pytest
...
@@ -5,10 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.utils
import
cdiv
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
99324e25
...
@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
...
@@ -39,10 +39,10 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_conti
n
gous"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_contig
u
ous"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
,
use_key
,
head_stride_is_conti
n
gous
):
seq_len
,
use_key
,
head_stride_is_contig
u
ous
):
batch_size
=
1
batch_size
=
1
base
=
10000
base
=
10000
num_heads
=
7
num_heads
=
7
...
@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -52,7 +52,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions
=
torch
.
randint
(
0
,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
device
=
device
)
head_stride
=
head_size
+
(
64
if
head_stride_is_conti
n
gous
else
0
)
head_stride
=
head_size
+
(
64
if
head_stride_is_contig
u
ous
else
0
)
query
=
torch
.
randn
(
batch_size
,
query
=
torch
.
randn
(
batch_size
,
seq_len
,
seq_len
,
...
@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -72,7 +72,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
# if we have a contiguous head stride, test the alternate
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
# [..., num_heads * head_dim] shape/layout
if
head_stride_is_conti
n
gous
:
if
head_stride_is_contig
u
ous
:
rotary_embedding_opcheck
(
rotary_embedding_opcheck
(
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
99324e25
...
@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
...
@@ -107,15 +107,15 @@ def generate_random_inputs(batch_size,
return
A
,
dt
,
X
,
B
,
C
return
A
,
dt
,
X
,
B
,
C
def
generate_continous_batched_examples
(
example_lens_by_batch
,
def
generate_contin
u
ous_batched_examples
(
example_lens_by_batch
,
num_examples
,
num_examples
,
full_length
,
full_length
,
last_taken
,
last_taken
,
exhausted
,
exhausted
,
n_heads
,
n_heads
,
d_head
,
d_head
,
itype
,
itype
,
device
=
'cuda'
):
device
=
'cuda'
):
# this function generates a random examples of certain length
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
# and then cut according to "example_lens_by_batch" and feed
...
@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -269,11 +269,10 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
states
=
None
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
C
)
in
generate_continous_batched_examples
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continuous_batched_examples
(
cases
,
num_examples
,
seqlen
,
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
\
chunk_indices
,
chunk_offsets
=
\
_query_start_loc_to_chunk_indices_offsets
(
_query_start_loc_to_chunk_indices_offsets
(
...
...
tests/kernels/moe/
deepep
_utils.py
→
tests/kernels/moe/
parallel
_utils.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
"""
DeepEP test utilities
DeepEP test utilities
"""
"""
import
dataclasses
import
dataclasses
import
importlib
import
importlib
import
os
import
traceback
import
traceback
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
...
@@ -13,6 +15,8 @@ from torch.multiprocessing import (
...
@@ -13,6 +15,8 @@ from torch.multiprocessing import (
spawn
)
# pyright: ignore[reportPrivateImportUsage]
spawn
)
# pyright: ignore[reportPrivateImportUsage]
from
typing_extensions
import
Concatenate
,
ParamSpec
from
typing_extensions
import
Concatenate
,
ParamSpec
from
vllm.utils
import
get_open_port
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
if
has_deep_ep
:
if
has_deep_ep
:
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
...
@@ -92,7 +96,7 @@ def parallel_launch(
...
@@ -92,7 +96,7 @@ def parallel_launch(
world_size
,
world_size
,
world_size
,
world_size
,
0
,
0
,
"tcp://
localhost:29500
"
,
f
"tcp://
{
os
.
getenv
(
'LOCALHOST'
,
'localhost'
)
}
:
{
get_open_port
()
}
"
,
worker
,
worker
,
)
+
args
,
)
+
args
,
nprocs
=
world_size
,
nprocs
=
world_size
,
...
@@ -134,18 +138,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
...
@@ -134,18 +138,14 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
low_latency_mode
=
low_latency_mode
,
low_latency_mode
=
low_latency_mode
,
num_qps_per_rank
=
num_qps_per_rank
)
num_qps_per_rank
=
num_qps_per_rank
)
return
DeepEPHTPrepareAndFinalize
(
buffer
=
buffer
,
return
DeepEPHTPrepareAndFinalize
(
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
num_dispatchers
=
pgi
.
world_size
,
rank
=
pgi
.
rank
,
dp_size
=
dp_size
,
dp_size
=
dp_size
,
rank_expert_offset
=
pgi
.
rank
*
rank_expert_offset
=
pgi
.
rank
*
ht_args
.
num_local_experts
,
ht_args
.
num_local_experts
)
quant_dtype
=
q_dtype
,
block_shape
=
block_shape
)
def
make_deepep_ll_a2a
(
pg
:
ProcessGroup
,
def
make_deepep_ll_a2a
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
deepep_ll_args
:
DeepEPLLArgs
,
deepep_ll_args
:
DeepEPLLArgs
,
q_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
q_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
):
block_shape
:
Optional
[
list
[
int
]]
=
None
):
...
@@ -165,11 +165,8 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
...
@@ -165,11 +165,8 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
return
DeepEPLLPrepareAndFinalize
(
return
DeepEPLLPrepareAndFinalize
(
buffer
=
buffer
,
buffer
=
buffer
,
world_size
=
pgi
.
world_size
,
num_dispatchers
=
pgi
.
world_size
,
dp_size
=
dp_size
,
max_tokens_per_rank
=
deepep_ll_args
.
max_tokens_per_rank
,
max_tokens_per_rank
=
deepep_ll_args
.
max_tokens_per_rank
,
quant_dtype
=
q_dtype
,
block_shape
=
block_shape
,
use_fp8_dispatch
=
deepep_ll_args
.
use_fp8_dispatch
,
use_fp8_dispatch
=
deepep_ll_args
.
use_fp8_dispatch
,
)
)
...
@@ -187,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup,
...
@@ -187,5 +184,4 @@ def make_deepep_a2a(pg: ProcessGroup,
block_shape
)
block_shape
)
assert
deepep_ll_args
is
not
None
assert
deepep_ll_args
is
not
None
return
make_deepep_ll_a2a
(
pg
,
pgi
,
dp_size
,
deepep_ll_args
,
q_dtype
,
return
make_deepep_ll_a2a
(
pg
,
pgi
,
deepep_ll_args
,
q_dtype
,
block_shape
)
block_shape
)
tests/kernels/moe/test_batched_moe.py
View file @
99324e25
...
@@ -2,18 +2,57 @@
...
@@ -2,18 +2,57 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
import
triton.language
as
tl
import
triton.language
as
tl
from
tests.kernels.moe.utils
import
(
batched_moe
,
make_quantized_test_activations
,
make_test_weights
,
naive_batched_moe
)
from
tests.kernels.quant_utils
import
native_batched_masked_quant_matmul
from
tests.kernels.utils
import
torch_experts
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
invoke_moe_batched_triton_kernel
)
invoke_moe_batched_triton_kernel
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.platforms
import
current_platform
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
128
,
2048
),
(
1
,
512
,
512
),
(
1
,
1024
,
128
),
(
1
,
1024
,
2048
),
(
32
,
128
,
128
),
(
32
,
512
,
512
),
(
32
,
1024
,
2048
),
(
45
,
128
,
128
),
(
45
,
128
,
2048
),
(
45
,
512
,
512
),
(
45
,
1024
,
128
),
(
45
,
1024
,
2048
),
(
64
,
512
,
512
),
(
64
,
1024
,
2048
),
(
222
,
128
,
128
),
(
222
,
128
,
2048
),
(
222
,
1024
,
128
),
(
222
,
1024
,
2048
),
]
NUM_EXPERTS
=
[
8
,
64
]
TOP_KS
=
[
1
,
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
dataclass
@
dataclass
class
BatchedMMConfig
:
class
BatchedMMConfig
:
dtype
:
torch
.
dtype
in_dtype
:
torch
.
dtype
quant_dtype
:
Optional
[
torch
.
dtype
]
out_dtype
:
torch
.
dtype
num_experts
:
int
num_experts
:
int
max_tokens_per_expert
:
int
max_tokens_per_expert
:
int
K
:
int
K
:
int
...
@@ -32,79 +71,129 @@ class BatchedMMTensors:
...
@@ -32,79 +71,129 @@ class BatchedMMTensors:
A
=
torch
.
randn
(
A
=
torch
.
randn
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
K
),
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
K
),
device
=
"cuda"
,
device
=
"cuda"
,
dtype
=
config
.
dtype
)
/
10
dtype
=
config
.
in_
dtype
)
/
10
B
=
torch
.
randn
((
config
.
num_experts
,
config
.
N
,
config
.
K
),
B
=
torch
.
randn
((
config
.
num_experts
,
config
.
N
,
config
.
K
),
device
=
"cuda"
,
device
=
"cuda"
,
dtype
=
config
.
dtype
)
dtype
=
config
.
in_
dtype
)
C
=
torch
.
zeros
(
C
=
torch
.
zeros
(
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
N
),
(
config
.
num_experts
,
config
.
max_tokens_per_expert
,
config
.
N
),
device
=
"cuda"
,
device
=
"cuda"
,
dtype
=
config
.
dtype
)
dtype
=
config
.
out_dtype
)
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
high
=
config
.
max_tokens_per_expert
,
high
=
config
.
max_tokens_per_expert
,
size
=
(
config
.
num_experts
,
),
size
=
(
config
.
num_experts
,
),
device
=
"cuda"
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
return
BatchedMMTensors
(
A
,
B
,
C
,
num_expert_tokens
)
return
BatchedMMTensors
(
A
,
B
,
C
,
num_expert_tokens
)
def
ref_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
,
16
,
32
])
num_expert_tokens
:
torch
.
Tensor
)
->
torch
.
Tensor
:
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
,
block_shape
:
Optional
[
list
[
int
]],
per_act_token_quant
:
bool
):
current_platform
.
seed_everything
(
7
)
num_expert_tokens_cpu
=
num_expert_tokens
.
clone
()
use_fp8_w8a8
=
dtype
==
torch
.
float8_e4m3fn
num_expert_tokens_cpu
=
num_expert_tokens_cpu
.
to
(
device
=
"cpu"
)
num_experts
=
num_expert_tokens
.
size
(
0
)
for
e
in
range
(
num_experts
):
if
(
per_act_token_quant
or
block_shape
is
not
None
)
and
not
use_fp8_w8a8
:
num_tokens
=
num_expert_tokens_cpu
[
e
]
pytest
.
skip
(
"Don't test blocking for non-quantized types."
)
C
[
e
,
:
num_tokens
,
:]
=
A
[
e
,
:
num_tokens
,
:]
@
B
[
e
].
transpose
(
0
,
1
)
return
C
if
per_act_token_quant
and
block_shape
is
not
None
:
pytest
.
skip
(
"Skip illegal quantization test."
)
if
dtype
.
itemsize
==
1
:
act_dtype
=
torch
.
bfloat16
quant_dtype
=
dtype
else
:
act_dtype
=
dtype
quant_dtype
=
None
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
16
,
32
])
num_expert_tokens
=
torch
.
randint
(
low
=
0
,
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
high
=
max_tokens_per_expert
,
[
32
,
64
,
128
,
192
,
224
,
256
,
512
])
size
=
(
num_experts
,
),
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
256
,
1024
])
device
=
"cuda"
,
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
256
,
512
,
1024
])
dtype
=
torch
.
int32
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
N
:
int
,
dtype
:
torch
.
dtype
):
config
=
BatchedMMConfig
(
dtype
,
num_experts
,
max_tokens_per_expert
,
K
,
N
)
A
,
A_q
,
A_scale
=
make_quantized_test_activations
(
tensors
=
BatchedMMTensors
.
make_tensors
(
config
)
num_experts
,
max_tokens_per_expert
,
K
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
test_output
=
tensors
.
C
B
,
B_q
,
B_scale
,
_
,
_
,
_
=
make_test_weights
(
ref_output
=
test_output
.
clone
()
num_experts
,
N
//
2
,
K
,
in_dtype
=
act_dtype
,
quant_dtype
=
quant_dtype
,
block_shape
=
block_shape
,
per_act_token_quant
=
per_act_token_quant
,
)
out_shape
=
(
num_experts
,
max_tokens_per_expert
,
N
)
test_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
ref_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
q_ref_output
=
torch
.
zeros
(
out_shape
,
dtype
=
act_dtype
,
device
=
"cuda"
)
compute_tl_dtype
=
{
compute_tl_dtype
=
{
torch
.
float16
:
tl
.
float16
,
torch
.
float16
:
tl
.
float16
,
torch
.
bfloat16
:
tl
.
bfloat16
,
torch
.
bfloat16
:
tl
.
bfloat16
,
torch
.
float32
:
tl
.
float32
torch
.
float32
:
tl
.
float32
}[
test_output
.
dtype
]
}[
test_output
.
dtype
]
assert
A_q
.
dtype
==
B_q
.
dtype
invoke_moe_batched_triton_kernel
(
invoke_moe_batched_triton_kernel
(
tensors
.
A
,
A_q
,
tensors
.
B
,
B_q
,
test_output
,
test_output
,
tensors
.
num_expert_tokens
,
num_expert_tokens
,
compute_tl_dtype
,
compute_tl_dtype
,
# Quantization data
# Quantization data
Non
e
,
A_scal
e
,
Non
e
,
B_scal
e
,
None
,
None
,
# Quantization schemes
# Quantization schemes
False
,
use_fp8_w8a8
,
False
,
False
,
False
,
False
,
config
=
{
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
16
"BLOCK_SIZE_K"
:
16
if
dtype
.
itemsize
>
1
else
32
})
},
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
)
ref_output
=
ref_impl
(
tensors
.
A
,
tensors
.
B
,
ref_output
,
ref_output
=
native_batched_masked_quant_matmul
(
tensors
.
num_expert_tokens
)
A
,
B
,
ref_output
,
num_expert_tokens
,
)
q_ref_output
=
native_batched_masked_quant_matmul
(
A_q
,
B_q
,
q_ref_output
,
num_expert_tokens
,
A_scale
,
B_scale
,
block_shape
,
per_act_token_quant
)
rtol
,
atol
=
{
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
torch
.
float16
:
(
6e-2
,
6e-2
),
...
@@ -112,4 +201,122 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
...
@@ -112,4 +201,122 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch
.
float32
:
(
1e-2
,
1e-2
),
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
test_output
.
dtype
]
}[
test_output
.
dtype
]
torch
.
testing
.
assert_close
(
test_output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
ref_output
,
q_ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
test_output
,
q_ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    input_scales: bool,
):
    """Cross-check the batched MoE implementations on identical inputs.

    ``naive_batched_moe`` is compared against ``torch_experts`` and
    ``batched_moe`` is then compared against ``naive_batched_moe``; all
    three receive the same activations, weights, scales and top-k routing.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    has_block_shape = block_shape is not None

    # Reject parameter combinations that cannot be exercised.
    if topk > e:
        pytest.skip("topk > e")
    if not use_fp8_w8a8 and (per_act_token_quant or has_block_shape):
        pytest.skip("Skip quantization test for non-quantized type")
    if per_act_token_quant and has_block_shape:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # A one-byte dtype is treated as the quantized weight type; the
    # activations then stay in bfloat16.
    one_byte_dtype = dtype.itemsize == 1
    act_dtype = torch.bfloat16 if one_byte_dtype else dtype
    quant_dtype = dtype if one_byte_dtype else None

    # The unquantized weight copies returned by the helper are not needed.
    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
    )

    a1_scale = None
    a2_scale = None
    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)

    # Keyword arguments shared by every implementation under comparison.
    moe_kwargs = dict(
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids,
                                        **moe_kwargs)
        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids,
                                           **moe_kwargs)
        triton_output = batched_moe(a, w1, w2, topk_weight, topk_ids,
                                    **moe_kwargs)

    torch.testing.assert_close(batched_output,
                               baseline_output,
                               atol=3e-2,
                               rtol=2e-2)
    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)
tests/kernels/moe/test_block_fp8.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
tests.kernels.moe.utils
import
make_test_weights
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_fp8
,
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
)
from
vllm.platforms
import
current_platform
dg_available
=
False
try
:
import
deep_gemm
dg_available
=
True
except
ImportError
:
pass
if
current_platform
.
get_device_capability
()
<
(
9
,
0
):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# Test configurations
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
512
,
512
),
(
1
,
128
,
7168
),
(
1
,
1024
,
7168
),
(
1
,
4608
,
128
),
(
1
,
4608
,
512
),
(
1
,
4608
,
7168
),
(
83
,
128
,
128
),
(
83
,
512
,
512
),
(
83
,
1024
,
7168
),
(
83
,
4608
,
512
),
(
83
,
4608
,
7168
),
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
1024
,
7168
),
(
128
,
4608
,
512
),
(
128
,
4608
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
1024
,
7168
),
(
2048
,
4608
,
512
),
(
2048
,
4608
,
7168
),
(
8192
,
128
,
128
),
(
8192
,
512
,
512
),
(
8192
,
128
,
7168
),
(
8192
,
1024
,
7168
),
(
8192
,
4608
,
512
),
(
8192
,
4608
,
7168
),
]
MNK_FACTORS_DG
=
[
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
128
,
7168
),
(
128
,
1024
,
7168
),
(
128
,
4608
,
128
),
(
128
,
4608
,
512
),
(
128
,
4608
,
7168
),
(
192
,
128
,
128
),
(
192
,
512
,
512
),
(
192
,
1024
,
7168
),
(
192
,
4608
,
512
),
(
192
,
4608
,
7168
),
(
1335
,
128
,
128
),
(
1335
,
1024
,
7168
),
(
1335
,
4608
,
512
),
(
1335
,
4608
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
512
,
512
),
(
2048
,
128
,
7168
),
(
2048
,
1024
,
7168
),
(
2048
,
4608
,
128
),
(
2048
,
4608
,
512
),
(
2048
,
4608
,
7168
),
]
BLOCK_SIZE
=
[[
128
,
128
]]
E
=
[
2
,
8
,
16
]
# [128, 256]
TOP_KS
=
[
1
,
2
,
6
]
SEEDS
=
[
0
]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                             block_shape):
    """Fused moe with block-wise quantization using native torch.

    Each row of ``a`` is replicated once per selected expert and quantized
    per token group of width ``block_shape[1]``. For every expert that has
    rows routed to it, the rows go through a block matmul with ``w1``,
    ``SiluAndMul`` and a second block matmul with ``w2``. The per-slot
    results are then weighted by ``topk_weight`` and summed per token.
    """
    num_tokens, hidden = a.shape
    topk = topk_ids.size(1)
    out_features = w2.shape[1]

    # One row per (token, top-k slot) pair.
    a = a.view(num_tokens, -1, hidden)
    a = a.repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(num_tokens * topk,
                      out_features,
                      dtype=a.dtype,
                      device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)

    num_experts = w1.shape[0]
    for expert in range(num_experts):
        mask = topk_ids == expert
        if not mask.sum():
            # No rows were routed to this expert.
            continue

        inter_out = native_w8a8_block_matmul(a_q[mask],
                                             w1[expert],
                                             a_s[mask],
                                             w1_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)
        act_out = SiluAndMul().forward_native(inter_out)
        # Re-quantize the activation output before the second matmul.
        act_out_q, act_out_s = native_per_token_group_quant_fp8(
            act_out, block_k)
        out[mask] = native_w8a8_block_matmul(act_out_q,
                                             w2[expert],
                                             act_out_s,
                                             w2_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)

    per_slot = out.view(num_tokens, -1, out_features)
    slot_weights = topk_weight.view(num_tokens, -1, 1).to(out.dtype)
    return (per_slot * slot_weights).sum(dim=1)
# Skip all tests if CUDA is not available
pytest
.
importorskip
(
"torch.cuda"
)
@pytest.fixture(autouse=True)
def setup_cuda():
    """Make "cuda" the default torch device for each test in this module."""
    default_device = "cuda"
    torch.set_default_device(default_device)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                  monkeypatch):
    """Check block-quantized FP8 fused MoE against the torch reference.

    Both ``fused_experts`` and the kernel built by
    ``modular_triton_fused_moe`` are compared with
    ``torch_w8a8_block_fp8_moe`` on the same inputs.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    # Only the quantized weights and their scales are used below.
    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_act_token_quant=False,
                                           block_shape=block_size)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    weight_scales = dict(w1_scale=w1_s, w2_scale=w2_s)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids,
                                           block_size)

        out = fused_experts(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            use_fp8_w8a8=True,
                            block_shape=block_size,
                            **weight_scales)

        m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids,
                            **weight_scales)

    # 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
    if M < 40000:
        tol = 0.035
    else:
        tol = 0.039

    for candidate in (out, m_out):
        torch.testing.assert_close(candidate, ref_out, atol=tol, rtol=tol)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    """Check ``deep_gemm_moe_fp8`` against the torch block-FP8 reference.

    For problem sizes that satisfy the ``use_cudagraph`` condition below,
    the kernel is additionally captured in a CUDA graph and replayed, and
    the replayed output is what gets compared.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    # The quantization block is square, sized by DeepGEMM's M alignment.
    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Positional arguments shared by every call to the kernel under test.
    kernel_args = (a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids,
                                           block_size)

        deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            for dynamic_input in (a, topk_weights, topk_ids):
                torch._dynamo.mark_dynamic(dynamic_input, 0)

        out = deep_gemm_moe_fp8_fn(*kernel_args)

        if use_cudagraph:
            # Clear the eager result, then capture and replay the same call.
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(*kernel_args)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
tests/kernels/moe/test_block_int8.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
tests.kernels.moe.utils
import
make_test_weights
from
tests.kernels.quant_utils
import
(
native_per_token_group_quant_int8
,
native_w8a8_block_matmul
)
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.platforms
import
current_platform
if
current_platform
.
get_device_capability
()
<
(
7
,
0
):
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
]
MNK_FACTORS
=
[
(
1
,
128
,
128
),
(
1
,
512
,
512
),
(
1
,
128
,
7168
),
(
1
,
1024
,
7168
),
(
1
,
4096
,
128
),
(
1
,
4096
,
512
),
(
1
,
4096
,
7168
),
(
33
,
128
,
128
),
(
33
,
512
,
512
),
(
33
,
128
,
7168
),
(
33
,
1024
,
7168
),
(
33
,
4096
,
128
),
(
33
,
4096
,
512
),
(
33
,
4096
,
7168
),
(
128
,
128
,
128
),
(
128
,
512
,
512
),
(
128
,
1024
,
7168
),
(
128
,
4096
,
512
),
(
128
,
4096
,
7168
),
(
222
,
128
,
128
),
(
222
,
512
,
512
),
(
222
,
1024
,
7168
),
(
222
,
4096
,
512
),
(
222
,
4096
,
7168
),
(
2048
,
128
,
128
),
(
2048
,
1024
,
7168
),
(
2048
,
4096
,
512
),
(
2048
,
4096
,
7168
),
]
E
=
[
8
,
24
]
TOP_KS
=
[
2
,
6
]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE
=
[[
128
,
128
]]
SEEDS
=
[
0
]
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch.

    Routing is computed here: a float32 softmax over ``score`` followed by
    ``torch.topk``. Each row of ``a`` is replicated ``topk`` times and
    quantized per token group of width ``block_shape[1]``. For every expert
    with routed rows, the rows go through a block matmul with ``w1``,
    ``SiluAndMul``, re-quantization and a block matmul with ``w2``. The
    per-slot results are weighted by the top-k weights and summed per token.
    """
    B, D = a.shape
    # One row per (token, top-k slot) pair.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            # Only the quantized activations and their scales feed the
            # second matmul; the float32 cast of act_out that used to follow
            # was never read, so it is dropped.
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Select "cuda" as the default torch device, once per module."""
    default_device = "cuda"
    torch.set_default_device(default_device)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Compare ``fused_moe`` with W8A8 INT8 block quantization to the
    native torch reference ``torch_w8a8_block_int8_moe``."""
    torch.manual_seed(seed)

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    # Only the quantized weights and their scales are used below.
    weights = make_test_weights(E,
                                N,
                                K,
                                dtype,
                                torch.int8,
                                per_act_token_quant=False,
                                block_shape=block_size)
    _, w1, w1_s, _, w2, w2_s = weights

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_int8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score,
                                            topk, block_size)

    # Check results
    tolerance = 0.065
    torch.testing.assert_close(out, ref_out, atol=tolerance, rtol=tolerance)
tests/kernels/moe/test_cutlass_grouped_gemm.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py
import
random
import
pytest
import
torch
from
tests.kernels.utils
import
baseline_scaled_mm
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
def cdiv(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for positive
    integers."""
    biased = a + b - 1
    return biased // b
def
per_token_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
pad_size
=
(
128
-
(
n
%
128
))
%
128
x
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
pad_size
),
value
=
0
)
if
pad_size
>
0
else
x
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
fp8_data
=
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
fp8_data
.
view
(
m
,
n
+
pad_size
)[:,
:
n
],
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
((
cdiv
(
m
,
128
)
*
128
,
cdiv
(
n
,
128
)
*
128
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
    (4, 8192, 7168, 4096),
    (4, 8192, 2048, 7168),
    (8, 4096, 7168, 4096),
    (8, 4096, 2048, 7168),
    (32, 1024, 7168, 4096),
    (32, 1024, 2048, 7168),
])
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
    (lambda x: x is None or x.to_int() != 100)(
        current_platform.get_device_capability()),
    reason="Block Scaled Grouped GEMM is only supported on SM100.")
def test_cutlass_grouped_gemm(
    num_groups: int,
    expected_m_per_group: int,
    k: int,
    n: int,
    out_dtype: torch.dtype,
):
    """Compare ``ops.cutlass_blockwise_scaled_grouped_mm`` with a per-group
    ``baseline_scaled_mm`` reference on randomly sized groups."""
    device = "cuda"
    alignment = 128

    # Each group gets a random row count within +/-30% of the expectation.
    group_ms = [
        int(expected_m_per_group * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    # Total rows: every group's size rounded up to the alignment.
    m = sum(cdiv(group_m, alignment) * alignment for group_m in group_ms)

    x = torch.randn((m, k), device=device, dtype=out_dtype)
    y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
    out = torch.empty((m, n), device=device, dtype=out_dtype)
    ref_out = torch.randn((m, n), device=device, dtype=out_dtype)

    # Row offsets: running sum of the unaligned group sizes, except that the
    # final boundary is the aligned total ``m``.
    ep_offset = [0]
    for group_m in group_ms[:-1]:
        ep_offset.append(ep_offset[-1] + group_m)
    ep_offset.append(m)
    group_bounds = list(zip(ep_offset[:-1], ep_offset[1:]))

    pb_size = [[end - start, n, k] for start, end in group_bounds]
    problem_sizes = torch.tensor(pb_size, device=device, dtype=torch.int32)
    expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)

    x_q, x_scale = per_token_cast_to_fp8(x)

    y_q = torch.empty_like(y, dtype=torch.float8_e4m3fn)
    y_scale = torch.empty((num_groups, cdiv(n, 128), k // 128),
                          device=device,
                          dtype=torch.float)
    for group in range(num_groups):
        y_q[group], y_scale[group] = per_block_cast_to_fp8(y[group])

    # Reference result, computed one group at a time.
    for group, (start, end) in enumerate(group_bounds):
        baseline = baseline_scaled_mm(x_q[start:end], y_q[group].t(),
                                      x_scale[start:end],
                                      y_scale[group].t(), out_dtype)
        ref_out[start:end] = baseline

    ops.cutlass_blockwise_scaled_grouped_mm(
        out,
        x_q,
        y_q,
        x_scale,
        y_scale,
        problem_sizes,
        expert_offsets[:-1],
    )

    torch.testing.assert_close(ref_out, out, atol=5e-1, rtol=1e-3)
tests/kernels/moe/test_cutlass_moe.py
View file @
99324e25
...
@@ -29,6 +29,10 @@ MNK_FACTORS = [
...
@@ -29,6 +29,10 @@ MNK_FACTORS = [
(
224
,
1024
,
1536
),
(
224
,
1024
,
1536
),
(
224
,
3072
,
1024
),
(
224
,
3072
,
1024
),
(
224
,
3072
,
1536
),
(
224
,
3072
,
1536
),
(
32768
,
1024
,
1024
),
# These sizes trigger wrong answers.
#(7232, 2048, 5120),
#(40000, 2048, 5120),
]
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
...
@@ -93,11 +97,9 @@ class MOETensors8Bit(MOETensors):
...
@@ -93,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales
=
2
*
n
if
per_out_channel
else
1
n_b_scales
=
2
*
n
if
per_out_channel
else
1
k_b_scales
=
k
if
per_out_channel
else
1
k_b_scales
=
k
if
per_out_channel
else
1
# Get the right scale for tests.
# Get the right scale for tests.
_
,
a_scale
=
ops
.
scaled_fp8_quant
(
a_q
,
a_scale
=
ops
.
scaled_fp8_quant
(
moe_tensors_fp16
.
a
,
use_per_token_if_dynamic
=
per_act_token
)
moe_tensors_fp16
.
a
,
None
,
use_per_token_if_dynamic
=
per_act_token
)
a_q
,
_
=
ops
.
scaled_fp8_quant
(
moe_tensors_fp16
.
a
,
a_scale
,
use_per_token_if_dynamic
=
per_act_token
)
w1_q
=
torch
.
empty
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
q_dtype
)
w1_q
=
torch
.
empty
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
q_dtype
)
w2_q
=
torch
.
empty
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
q_dtype
)
w2_q
=
torch
.
empty
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
q_dtype
)
...
@@ -183,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
...
@@ -183,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def
run_8_bit
(
moe_tensors
:
MOETensors8Bit
,
def
run_8_bit
(
moe_tensors
:
MOETensors8Bit
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
per_act_token
:
bool
,
num_local_experts
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
num_local_experts
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
assert
not
any
([
assert
not
any
([
t
is
None
for
t
in
[
t
is
None
for
t
in
[
...
@@ -199,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -199,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'a1_scale'
:
moe_tensors
.
a_scale
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
}
}
num_experts
=
moe_tensors
.
w1
.
size
(
0
)
num_experts
=
moe_tensors
.
w1
.
size
(
0
)
...
@@ -231,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
...
@@ -231,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
topk
:
int
,
topk
:
int
,
per_act_token
:
bool
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
per_out_ch
:
bool
,
monkeypatch
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_ch
)
per_out_ch
)
...
@@ -248,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
...
@@ -248,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output
=
fused_experts
(
mt
.
a_d
,
mt
.
w1_d
,
mt
.
w2_d
,
topk_weights
,
triton_output
=
fused_experts
(
mt
.
a_d
,
mt
.
w1_d
,
mt
.
w2_d
,
topk_weights
,
topk_ids
)
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
,
per_act_token
)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch
.
testing
.
assert_close
(
triton_output
,
torch
.
testing
.
assert_close
(
triton_output
,
cutlass_output
,
cutlass_output
,
atol
=
5e-2
,
atol
=
5.
5e-2
,
rtol
=
1e-2
)
rtol
=
1e-2
)
...
@@ -273,8 +281,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
...
@@ -273,8 +281,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
topk
:
int
,
topk
:
int
,
per_act_token
:
bool
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
per_out_ch
:
bool
,
monkeypatch
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
dtype
=
torch
.
half
dtype
=
torch
.
half
...
@@ -295,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
...
@@ -295,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream
=
torch
.
cuda
.
Stream
()
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
)
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_ids
,
per_act_token
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
graph
.
replay
()
graph
.
replay
()
...
@@ -328,8 +339,10 @@ def test_cutlass_moe_8_bit_EP(
...
@@ -328,8 +339,10 @@ def test_cutlass_moe_8_bit_EP(
per_act_token
:
bool
,
per_act_token
:
bool
,
per_out_channel
:
bool
,
per_out_channel
:
bool
,
ep_size
:
int
,
ep_size
:
int
,
monkeypatch
,
):
):
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
mt
=
MOETensors8Bit
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
per_act_token
,
per_out_channel
)
per_out_channel
)
...
@@ -349,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
...
@@ -349,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output
=
run_8_bit
(
mt
,
cutlass_output
=
run_8_bit
(
mt
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
per_act_token
,
num_local_experts
=
e
//
ep_size
)
num_local_experts
=
e
//
ep_size
)
torch
.
testing
.
assert_close
(
triton_output
,
torch
.
testing
.
assert_close
(
triton_output
,
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
"""
Test DeepEP + DeepGEMM integration
Test DeepEP + DeepGEMM integration
DeepGEMM are gemm kernels specialized for the
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
fp8 block-quantized case.
"""
"""
import
dataclasses
import
dataclasses
import
importlib
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -18,41 +18,34 @@ from vllm.config import VllmConfig, set_current_vllm_config
...
@@ -18,41 +18,34 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
FusedMoEModularKernel
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
.deepep_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
if
has_deep_ep
():
try
:
import
deep_gemm
has_deep_gemm
=
True
except
ImportError
:
has_deep_gemm
=
False
if
has_deep_ep
:
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
DeepEPLLPrepareAndFinalize
)
from
.deepep_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
from
.parallel_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
if
has_deep_gemm
():
if
has_deep_gemm
:
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
)
BatchedDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
)
DeepGemmExperts
)
requires_deep_ep
=
pytest
.
mark
.
skipif
(
requires_deep_ep
=
pytest
.
mark
.
skipif
(
not
has_deep_ep
,
not
has_deep_ep
()
,
reason
=
"Requires deep_ep kernels"
,
reason
=
"Requires deep_ep kernels"
,
)
)
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
has_deep_gemm
,
not
has_deep_gemm
()
,
reason
=
"Requires deep_gemm kernels"
,
reason
=
"Requires deep_gemm kernels"
,
)
)
...
@@ -66,25 +59,6 @@ def next_power_of_2(x):
...
@@ -66,25 +59,6 @@ def next_power_of_2(x):
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
return
2
**
math
.
ceil
(
math
.
log2
(
x
))
def
per_block_cast_to_fp8
(
x
:
torch
.
Tensor
,
block_size_n
:
int
=
128
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
block_size_n
)
*
block_size_n
),
dtype
=
x
.
dtype
,
device
=
x
.
device
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
block_size_n
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
x_scaled_sub
=
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
()
scales
=
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
return
x_scaled_sub
,
scales
def
make_block_quant_fp8_weights
(
def
make_block_quant_fp8_weights
(
e
:
int
,
e
:
int
,
n
:
int
,
n
:
int
,
...
@@ -92,43 +66,11 @@ def make_block_quant_fp8_weights(
...
@@ -92,43 +66,11 @@ def make_block_quant_fp8_weights(
block_size
:
list
[
int
],
block_size
:
list
[
int
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Return weights
w1, w2,
w1q, w2q, w1_scale, w2_scale
Return weights w1q, w2q, w1_scale, w2_scale
"""
"""
dtype
=
torch
.
bfloat16
w1
,
w1q
,
w1_scale
,
w2
,
w2q
,
w2_scale
=
make_test_weights
(
e
,
n
,
k
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
,
block_size
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
w1q
,
w2q
,
w1_scale
,
w2_scale
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
w1_bf16
=
torch
.
randn
((
e
,
2
*
n
,
k
),
dtype
=
dtype
)
/
10
w1_bf16
=
w1_bf16
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
dtype
=
dtype
)
w2_bf16
=
torch
.
randn
((
e
,
k
,
n
),
dtype
=
dtype
)
/
10
w2_bf16
=
w2_bf16
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
dtype
=
dtype
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles_w1
=
((
2
*
n
)
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
k
+
block_k
-
1
)
//
block_k
n_tiles_w2
=
(
k
+
block_n
-
1
)
//
block_n
k_tiles_w2
=
(
n
+
block_k
-
1
)
//
block_k
w1
=
torch
.
empty_like
(
w1_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w2
=
torch
.
empty_like
(
w2_bf16
,
dtype
=
torch
.
float8_e4m3fn
)
w1_s
=
torch
.
empty
((
e
,
n_tiles_w1
,
k_tiles_w1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_s
=
torch
.
empty
((
e
,
n_tiles_w2
,
k_tiles_w2
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
assert
w1_s
.
shape
==
(
e
,
(
2
*
n
+
127
)
//
128
,
(
k
+
127
)
//
128
)
assert
(
w2
.
shape
[
-
2
]
+
block_n
-
1
)
//
block_n
==
w2_s
.
shape
[
-
2
]
for
i
in
range
(
e
):
w1
[
i
],
w1_s
[
i
]
=
per_block_cast_to_fp8
(
w1_bf16
[
i
])
w2
[
i
],
w2_s
[
i
]
=
per_block_cast_to_fp8
(
w2_bf16
[
i
])
return
w1
,
w2
,
w1_s
,
w2_s
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -138,6 +80,7 @@ class TestConfig:
...
@@ -138,6 +80,7 @@ class TestConfig:
k
:
int
k
:
int
n
:
int
n
:
int
num_experts
:
int
num_experts
:
int
per_act_token_quant
:
bool
block_size
:
list
[
int
]
block_size
:
list
[
int
]
# configs for testing low-latency kernels
# configs for testing low-latency kernels
low_latency
:
bool
low_latency
:
bool
...
@@ -156,8 +99,7 @@ class TestTensors:
...
@@ -156,8 +99,7 @@ class TestTensors:
def
make
(
config
:
TestConfig
,
rank
)
->
"TestTensors"
:
def
make
(
config
:
TestConfig
,
rank
)
->
"TestTensors"
:
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
topk
,
m
,
k
,
block_size
=
(
config
.
topk
,
config
.
m
,
config
.
k
,
topk
,
m
,
k
=
(
config
.
topk
,
config
.
m
,
config
.
k
)
config
.
block_size
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
...
@@ -165,9 +107,7 @@ class TestTensors:
...
@@ -165,9 +107,7 @@ class TestTensors:
rank_tokens
=
torch
.
randn
(
rank_tokens
=
torch
.
randn
(
(
m
,
k
),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
/
10.0
(
m
,
k
),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
/
10.0
rank_tokens
=
rank_tokens
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
rank_tokens
=
rank_tokens
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
rank_token_scales
=
None
block_k
=
block_size
[
1
]
_
,
rank_token_scales
=
per_token_group_quant_fp8
(
rank_tokens
,
block_k
)
topk_ids
=
torch
.
randint
(
topk_ids
=
torch
.
randint
(
low
=
0
,
low
=
0
,
...
@@ -207,10 +147,11 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
...
@@ -207,10 +147,11 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype
=
q_dtype
,
q_dtype
=
q_dtype
,
block_shape
=
test_config
.
block_size
)
block_shape
=
test_config
.
block_size
)
fused_experts
=
BatchedDeepGemmExperts
(
max_num_tokens
=
max_tokens_per_rank
,
fused_experts
=
BatchedDeepGemmExperts
(
world_size
=
pgi
.
world_size
,
max_num_tokens
=
max_tokens_per_rank
,
dp_size
=
dp_size
,
num_dispatchers
=
pgi
.
world_size
//
dp_size
,
block_shape
=
test_config
.
block_size
)
block_shape
=
test_config
.
block_size
,
per_act_token_quant
=
test_config
.
per_act_token_quant
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
fused_experts
=
fused_experts
)
return
mk
return
mk
...
@@ -432,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
...
@@ -432,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
"""
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
"""
import
deep_gemm
m
,
n
,
k
=
mnk
m
,
n
,
k
=
mnk
current_platform
.
seed_everything
(
7
)
current_platform
.
seed_everything
(
7
)
...
@@ -448,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
...
@@ -448,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k
=
k
,
k
=
k
,
n
=
n
,
n
=
n
,
num_experts
=
num_experts
,
num_experts
=
num_experts
,
per_act_token_quant
=
False
,
block_size
=
block_size
,
block_size
=
block_size
,
low_latency
=
False
,
low_latency
=
False
,
use_fp8_dispatch
=
None
)
use_fp8_dispatch
=
None
)
...
@@ -480,10 +423,14 @@ USE_FP8_DISPATCH = [False]
...
@@ -480,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
requires_deep_ep
@
requires_deep_ep
@
requires_deep_gemm
@
requires_deep_gemm
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
def
test_ll_deepep_deepgemm_moe
(
int
],
num_experts
:
int
,
topk
:
int
,
mnk
:
tuple
[
int
,
int
,
int
],
use_fp8_dispatch
:
bool
,
block_size
:
list
[
int
],
num_experts
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
topk
:
int
,
use_fp8_dispatch
:
bool
,
block_size
:
list
[
int
],
world_dp_size
:
tuple
[
int
,
int
],
):
"""
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
"""
...
@@ -501,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
...
@@ -501,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k
=
k
,
k
=
k
,
n
=
n
,
n
=
n
,
num_experts
=
num_experts
,
num_experts
=
num_experts
,
per_act_token_quant
=
False
,
block_size
=
block_size
,
block_size
=
block_size
,
low_latency
=
True
,
low_latency
=
True
,
use_fp8_dispatch
=
use_fp8_dispatch
,
use_fp8_dispatch
=
use_fp8_dispatch
,
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
"""
Test deepep dispatch-combine logic
Test deepep dispatch-combine logic
"""
"""
import
dataclasses
import
dataclasses
import
importlib
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
pytest
...
@@ -22,21 +22,20 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -22,21 +22,20 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
per_token_group_quant_fp8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
from
.
deepep
_utils
import
ProcessGroupInfo
,
parallel_launch
from
.
parallel
_utils
import
ProcessGroupInfo
,
parallel_launch
has_deep_ep
=
importlib
.
util
.
find_spec
(
"deep_ep"
)
is
not
None
if
has_deep_ep
():
if
has_deep_ep
:
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize
import
(
# noqa: E501
DeepEPHTPrepareAndFinalize
)
DeepEPHTPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize
import
(
# noqa: E501
DeepEPLLPrepareAndFinalize
)
DeepEPLLPrepareAndFinalize
)
from
.
deepep
_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
from
.
parallel
_utils
import
DeepEPHTArgs
,
DeepEPLLArgs
,
make_deepep_a2a
requires_deep_ep
=
pytest
.
mark
.
skipif
(
requires_deep_ep
=
pytest
.
mark
.
skipif
(
not
has_deep_ep
,
not
has_deep_ep
()
,
reason
=
"Requires deep_ep kernels"
,
reason
=
"Requires deep_ep kernels"
,
)
)
...
@@ -104,10 +103,6 @@ class TestTensors:
...
@@ -104,10 +103,6 @@ class TestTensors:
rank_tokens
=
torch
.
randn
(
rank_tokens
=
torch
.
randn
(
(
config
.
m
,
config
.
k
),
device
=
"cuda"
,
dtype
=
token_dtype
)
/
10
(
config
.
m
,
config
.
k
),
device
=
"cuda"
,
dtype
=
token_dtype
)
/
10
rank_token_scales
=
None
rank_token_scales
=
None
if
config
.
dtype
==
torch
.
float8_e4m3fn
:
# low_latency_mode kernels dont support per-token quant.
_
,
rank_token_scales
=
ops
.
scaled_fp8_quant
(
rank_tokens
,
use_per_token_if_dynamic
=
not
low_latency_mode
)
topk
=
torch
.
randint
(
low
=
0
,
topk
=
torch
.
randint
(
low
=
0
,
high
=
config
.
num_experts
,
high
=
config
.
num_experts
,
...
@@ -123,11 +118,18 @@ class TestTensors:
...
@@ -123,11 +118,18 @@ class TestTensors:
config
=
config
)
config
=
config
)
def
make_modular_kernel
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
def
make_modular_kernel
(
low_latency_mode
:
bool
,
hidden_size
:
int
,
dp_size
:
int
,
pg
:
ProcessGroup
,
num_experts
:
int
,
num_local_experts
:
int
,
pgi
:
ProcessGroupInfo
,
q_dtype
:
Optional
[
torch
.
dtype
],
low_latency_mode
:
bool
,
use_fp8_dispatch
:
bool
)
->
FusedMoEModularKernel
:
hidden_size
:
int
,
dp_size
:
int
,
num_experts
:
int
,
num_local_experts
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
)
->
FusedMoEModularKernel
:
is_quantized
=
q_dtype
is
not
None
is_quantized
=
q_dtype
is
not
None
...
@@ -153,33 +155,47 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
...
@@ -153,33 +155,47 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ht_args
=
ht_args
,
deepep_ht_args
=
ht_args
,
deepep_ll_args
=
ll_args
)
deepep_ll_args
=
ll_args
)
num_dispatchers
=
pgi
.
world_size
//
dp_size
if
low_latency_mode
:
if
low_latency_mode
:
assert
not
per_act_token_quant
,
"not supported in ll mode"
fused_experts
=
BatchedTritonExperts
(
fused_experts
=
BatchedTritonExperts
(
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
max_num_tokens
=
MAX_TOKENS_PER_RANK
,
world_size
=
pgi
.
world_size
,
num_dispatchers
=
num_dispatchers
,
dp_size
=
dp_size
,
use_fp8_w8a8
=
is_quantized
,
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
)
use_int4_w4a16
=
False
,
per_act_token_quant
=
False
,
)
else
:
else
:
fused_experts
=
TritonExperts
(
use_fp8_w8a8
=
is_quantized
,
fused_experts
=
TritonExperts
(
use_int8_w8a8
=
False
,
use_fp8_w8a8
=
is_quantized
,
use_int8_w8a16
=
False
,
use_int8_w8a8
=
False
,
use_int4_w4a16
=
False
,
use_int8_w8a16
=
False
,
per_channel_quant
=
False
)
use_int4_w4a16
=
False
,
per_act_token_quant
=
per_act_token_quant
,
)
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
mk
=
FusedMoEModularKernel
(
prepare_finalize
=
a2a
,
fused_experts
=
fused_experts
)
fused_experts
=
fused_experts
)
return
mk
return
mk
def
deep_ep_moe_impl
(
pg
:
ProcessGroup
,
pgi
:
ProcessGroupInfo
,
def
deep_ep_moe_impl
(
low_latency_mode
:
bool
,
dp_size
:
int
,
pg
:
ProcessGroup
,
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
pgi
:
ProcessGroupInfo
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
low_latency_mode
:
bool
,
w2_scale
:
Optional
[
torch
.
Tensor
],
num_experts
:
int
,
dp_size
:
int
,
use_fp8_dispatch
:
bool
)
->
torch
.
Tensor
:
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
num_experts
:
int
,
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
)
->
torch
.
Tensor
:
num_local_experts
=
w1
.
size
(
0
)
num_local_experts
=
w1
.
size
(
0
)
...
@@ -201,11 +217,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
...
@@ -201,11 +217,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype
=
torch
.
float8_e4m3fn
q_dtype
=
torch
.
float8_e4m3fn
# Make modular kernel
# Make modular kernel
mk
:
FusedMoEModularKernel
=
make_modular_kernel
(
pg
,
pgi
,
low_latency_mode
,
mk
:
FusedMoEModularKernel
=
make_modular_kernel
(
hidden_size
,
dp_size
,
pg
,
pgi
,
low_latency_mode
,
hidden_size
,
dp_size
,
num_experts
,
num_experts
,
num_local_experts
,
q_dtype
,
use_fp8_dispatch
,
per_act_token_quant
)
num_local_experts
,
q_dtype
,
use_fp8_dispatch
)
out_hidden_states
=
torch
.
empty_like
(
test_tensors
.
rank_tokens
)
out_hidden_states
=
torch
.
empty_like
(
test_tensors
.
rank_tokens
)
total_num_tokens
=
test_tensors
.
rank_tokens
.
size
(
0
)
total_num_tokens
=
test_tensors
.
rank_tokens
.
size
(
0
)
...
@@ -259,9 +273,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
...
@@ -259,9 +273,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return
out_hidden_states
return
out_hidden_states
def
torch_moe_impl
(
test_tensors
:
TestTensors
,
w1
:
torch
.
Tensor
,
def
torch_moe_impl
(
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
test_tensors
:
TestTensors
,
w2_scale
:
Optional
[
torch
.
Tensor
],
using_fp8_dispatch
:
bool
):
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
using_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
):
a
,
topk_ids
,
topk_weights
=
(
test_tensors
.
rank_tokens
,
test_tensors
.
topk
,
a
,
topk_ids
,
topk_weights
=
(
test_tensors
.
rank_tokens
,
test_tensors
.
topk
,
test_tensors
.
topk_weights
)
test_tensors
.
topk_weights
)
...
@@ -269,6 +289,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
...
@@ -269,6 +289,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8.
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
# For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant.
# blockwise quant and de-quant.
assert
not
per_act_token_quant
a
=
test_tensors
.
rank_tokens
a
=
test_tensors
.
rank_tokens
aq
,
aq_scale
=
per_token_group_quant_fp8
(
a
,
128
)
aq
,
aq_scale
=
per_token_group_quant_fp8
(
a
,
128
)
a
=
(
aq
.
view
(
-
1
,
128
).
to
(
torch
.
float32
)
*
aq_scale
.
view
(
-
1
,
1
)).
view
(
a
=
(
aq
.
view
(
-
1
,
128
).
to
(
torch
.
float32
)
*
aq_scale
.
view
(
-
1
,
1
)).
view
(
...
@@ -312,6 +333,7 @@ def _deep_ep_moe(
...
@@ -312,6 +333,7 @@ def _deep_ep_moe(
w1_scale
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
use_fp8_dispatch
:
bool
,
use_fp8_dispatch
:
bool
,
per_act_token_quant
:
bool
,
):
):
if
not
low_latency_mode
:
if
not
low_latency_mode
:
...
@@ -333,7 +355,8 @@ def _deep_ep_moe(
...
@@ -333,7 +355,8 @@ def _deep_ep_moe(
with
set_current_vllm_config
(
VllmConfig
()):
with
set_current_vllm_config
(
VllmConfig
()):
# Reference
# Reference
torch_combined
=
torch_moe_impl
(
test_tensors
,
w1
,
w2
,
w1_scale
,
torch_combined
=
torch_moe_impl
(
test_tensors
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
w2_scale
,
use_fp8_dispatch
,
per_act_token_quant
)
# Splice experts for this rank.
# Splice experts for this rank.
num_local_experts
=
config
.
num_experts
//
pgi
.
world_size
num_local_experts
=
config
.
num_experts
//
pgi
.
world_size
...
@@ -358,6 +381,7 @@ def _deep_ep_moe(
...
@@ -358,6 +381,7 @@ def _deep_ep_moe(
w2_scale_ep
,
w2_scale_ep
,
config
.
num_experts
,
config
.
num_experts
,
use_fp8_dispatch
,
use_fp8_dispatch
,
per_act_token_quant
,
)
)
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
...
@@ -386,10 +410,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
...
@@ -386,10 +410,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
requires_deep_ep
@
requires_deep_ep
def
test_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
def
test_deep_ep_moe
(
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
dtype
:
torch
.
dtype
,
int
]):
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
],
per_act_token_quant
:
bool
,
):
low_latency_mode
=
False
low_latency_mode
=
False
use_fp8_dispatch
=
False
use_fp8_dispatch
=
False
m
,
n
,
k
=
mnk
m
,
n
,
k
=
mnk
...
@@ -406,7 +436,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
...
@@ -406,7 +436,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
,
per_act_token_quant
)
MNKs
=
[
MNKs
=
[
...
@@ -456,4 +487,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
...
@@ -456,4 +487,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
w1
,
w2
,
w1_scale
,
w2_scale
=
make_weights
(
num_experts
,
n
,
k
,
dtype
)
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
parallel_launch
(
world_size
,
_deep_ep_moe
,
low_latency_mode
,
dp_size
,
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
)
config
,
w1
,
w2
,
w1_scale
,
w2_scale
,
use_fp8_dispatch
,
False
)
tests/kernels/moe/test_deepgemm.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""
import
importlib
import
math
import
pytest
import
torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
cdiv
has_deep_gemm
=
importlib
.
util
.
find_spec
(
"deep_gemm"
)
is
not
None
if
has_deep_gemm
:
import
deep_gemm
BLOCK_M
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
BLOCK_SIZE
=
[
BLOCK_M
,
BLOCK_M
]
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
has_deep_gemm
,
reason
=
"Requires deep_gemm kernels"
,
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a normalized dissimilarity score between two tensors.

    Both inputs are promoted to float64 and the score is computed as
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)``. Identical inputs give 0,
    inputs whose element-wise products sum to zero give 1, and ``y == -x``
    gives 2. If both inputs are entirely zero the division is 0 / 0 and
    the result is NaN.

    Args:
        x: First tensor.
        y: Second tensor; must be broadcast-compatible with ``x``.

    Returns:
        A zero-dimensional float64 tensor holding the score.
    """
    lhs = x.double()
    rhs = y.double()
    # Summed squared magnitude of both operands (the normalizer).
    energy = (lhs * lhs + rhs * rhs).sum()
    similarity = 2 * (lhs * rhs).sum() / energy
    return 1 - similarity
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to float8_e4m3fn with one scale per block.

    The matrix is tiled into blocks of 128 rows by ``block_size_n`` columns
    (zero-padding the edges as needed). Each block is scaled so that its
    largest absolute value maps to 448.0 before the cast to FP8.

    Args:
        x: 2-D tensor of shape ``(m, n)``.
        block_size_n: Block width along the last dimension. The block
            height is fixed at 128 rows.

    Returns:
        A tuple ``(x_fp8, scales)`` where ``x_fp8`` is a contiguous
        float8_e4m3fn tensor of shape ``(m, n)`` and ``scales`` is a float32
        tensor of shape ``(ceil(m / 128), ceil(n / block_size_n))`` holding
        the per-block de-quantization factors (``amax / 448``).
    """
    assert x.dim() == 2
    block_size_m = 128
    m, n = x.shape

    # Zero-pad up to whole blocks so the matrix can be viewed as a grid.
    x_padded = torch.zeros((cdiv(m, block_size_m) * block_size_m,
                            cdiv(n, block_size_n) * block_size_n),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x

    # (row_blocks, rows_in_block, col_blocks, cols_in_block).
    # The column-block count must be derived from block_size_n; dividing by a
    # hard-coded 128 only produced the right layout for the default width.
    x_view = x_padded.view(-1, block_size_m,
                           x_padded.size(1) // block_size_n, block_size_n)

    # Per-block absolute maximum; the clamp avoids dividing by zero for
    # all-zero blocks.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)

    # Drop the padding again so the output matches the input shape.
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Create random expert weights in FP8 block-quantized form on the GPU.

    bf16 reference weights are sampled (w1 first, then w2), clamped to the
    float8_e4m3fn range, and quantized expert by expert with
    ``per_block_cast_to_fp8`` (called with its default block width).

    Args:
        e: Number of experts.
        n: Intermediate size; w1 stacks two projections, hence 2 * n rows.
        k: Hidden size.
        block_size: ``[block_n, block_k]`` used to size the scale tensors.

    Returns:
        ``(w1, w2, w1_s, w2_s)`` with
        w1 shape: (E, 2N, K), w2 shape: (E, K, N), and float32 scales of
        shape (E, ceil(2N / block_n), ceil(K / block_k)) and
        (E, ceil(K / block_n), ceil(N / block_k)) respectively.
    """
    fp8_dtype = torch.float8_e4m3fn
    fp8_info = torch.finfo(fp8_dtype)

    # bf16 reference weights
    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda",
                          dtype=torch.bfloat16) / 10
    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=torch.bfloat16) / 10
    for reference in (w1_bf16, w2_bf16):
        reference.clamp_(fp8_info.min, fp8_info.max)

    block_n, block_k = block_size

    def _empty_scales(rows: int, cols: int) -> torch.Tensor:
        # One float32 scale per (block_n x block_k) tile of a rows x cols
        # matrix, for every expert.
        return torch.empty(e,
                           math.ceil(rows / block_n),
                           math.ceil(cols / block_k),
                           device="cuda",
                           dtype=torch.float32)

    w1 = torch.empty_like(w1_bf16, dtype=fp8_dtype)
    w2 = torch.empty_like(w2_bf16, dtype=fp8_dtype)
    w1_s = _empty_scales(2 * n, k)
    w2_s = _empty_scales(k, n)

    for expert in range(e):
        w1[expert], w1_s[expert] = per_block_cast_to_fp8(w1_bf16[expert])
        w2[expert], w2_s[expert] = per_block_cast_to_fp8(w2_bf16[expert])

    return w1, w2, w1_s, w2_s
def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Check the DeepGEMM path against the Triton path for one problem size.

    Random bf16 tokens of shape (m, k) and block-quantized FP8 expert
    weights are pushed through ``fused_experts`` twice, once with
    ``allow_deep_gemm=False`` (reference) and once with
    ``allow_deep_gemm=True``; the two outputs must agree within tolerance.
    """
    tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    tokens_bf16.clamp_min_(-1).clamp_max_(1)
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
                                                      block_size)

    # Random routing: pick the top-k experts per token and normalize the
    # selected logits into weights.
    router_logits = torch.randn(m,
                                num_experts,
                                device="cuda",
                                dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # Everything except allow_deep_gemm is shared by both runs.
    shared_kwargs = dict(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
    )

    # triton reference
    out_triton = fused_experts(allow_deep_gemm=False, **shared_kwargs)

    # DeepGemm
    out_deepgemm = fused_experts(allow_deep_gemm=True, **shared_kwargs)

    # 10% of the mean magnitude, but not lower than 1e-3
    mean_magnitude = out_triton.abs().mean()
    atol = float(0.1 * mean_magnitude.clamp(min=1e-2))
    rtol = 0.05

    # ----- Compare -----
    torch.testing.assert_close(
        out_deepgemm.to(torch.float32),
        out_triton.to(torch.float32),
        rtol=rtol,
        atol=atol,
    )
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
MNKs
=
[
(
1024
,
512
,
128
),
(
1024
,
512
,
512
),
(
2048
,
512
,
512
),
(
512
,
1024
,
1024
),
(
512
,
2048
,
2048
),
(
4096
,
4096
,
1024
),
]
TOPKS
=
[
2
,
6
]
NUM_EXPERTS
=
[
32
]
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@requires_deep_gemm
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
    """Compare DeepGEMM with Triton and verify DeepGEMM actually ran.

    ``deep_gemm_moe_fp8`` in the fused_moe module is wrapped by a counting
    spy so the test fails if the DeepGEMM kernel was not dispatched exactly
    once during the run.
    """
    with monkeypatch.context() as env_patch:
        env_patch.setenv("VLLM_USE_DEEP_GEMM", "1")

        fused_moe_module = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe")
        real_deep_gemm_moe_fp8 = fused_moe_module.deep_gemm_moe_fp8
        call_counter = {"cnt": 0}

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            # Count the call, then defer to the real implementation.
            call_counter["cnt"] += 1
            return real_deep_gemm_moe_fp8(*args, **kwargs)

        monkeypatch.setattr(fused_moe_module, "deep_gemm_moe_fp8",
                            _spy_deep_gemm_moe_fp8)

        m, n, k = mnk

        if topk > num_experts:
            pytest.skip(f"topk={topk} > num_experts={num_experts}")

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, (
            "DeepGEMM path was not executed during the test. "
            f"Call counter: {call_counter['cnt']}")
tests/kernels/moe/test_moe.py
View file @
99324e25
...
@@ -4,6 +4,9 @@
...
@@ -4,6 +4,9 @@
Run `pytest tests/kernels/test_moe.py`.
Run `pytest tests/kernels/test_moe.py`.
"""
"""
import
functools
from
typing
import
Callable
,
Optional
,
Union
import
pytest
import
pytest
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
...
@@ -14,8 +17,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...
@@ -14,8 +17,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
tests.kernels.utils
import
opcheck
,
stack_and_dev
,
torch_moe
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
modular_triton_fused_moe
)
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
...
@@ -39,7 +45,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
...
@@ -39,7 +45,76 @@ vllm_config.scheduler_config.max_num_seqs = 128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
vllm_config
.
scheduler_config
.
max_model_len
=
8192
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
1024
*
128
])
def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    """Run ``moe_fn`` and compare its output against a baseline.

    Args:
        baseline: Either a precomputed reference tensor, or a callable with
            the same calling convention as ``moe_fn`` that produces it. A
            callable baseline is evaluated on the unpadded weights.
        moe_fn: MoE implementation under test, called as
            ``moe_fn(a, w1, w2, score, topk, global_num_experts=...,
            expert_map=...)``.
        a, w1, w2, score, topk: Forwarded positionally to the callables.
        global_num_experts, expert_map: Forwarded as keyword arguments.
        padding: If true, re-create ``w1``/``w2`` as slices of a copy padded
            by 128 elements along the last dimension before running
            ``moe_fn``.
        use_compile: If true, wrap ``moe_fn`` with ``torch.compile`` and mark
            dim 0 of ``a`` and ``score`` as dynamic.
        use_cudagraph: If true, additionally capture a call to ``moe_fn`` in
            a CUDA graph, replay it, and compare the replayed output.
        atol, rtol: Tolerances passed to ``torch.testing.assert_close``.

    Returns:
        The baseline output, so it can be reused by later comparisons.
    """
    expert_kwargs = dict(global_num_experts=global_num_experts,
                         expert_map=expert_map)

    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(a, w1, w2, score, topk, **expert_kwargs)

    # Pad the weight if moe padding is enabled
    if padding:
        w1, w2 = (F.pad(weight, (0, 128), "constant", 0)[..., 0:-128]
                  for weight in (w1, w2))

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    def _call_moe() -> torch.Tensor:
        return moe_fn(a, w1, w2, score, topk, **expert_kwargs)

    test_output = _call_moe()

    if use_cudagraph:
        # Clear the eager result so that only the graph replay can produce
        # the values that are compared below.
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            test_output = _call_moe()
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(test_output,
                               baseline_output,
                               atol=atol,
                               rtol=rtol)

    return baseline_output
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
,
32768
,
40000
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
...
@@ -47,6 +122,7 @@ vllm_config.scheduler_config.max_model_len = 8192
...
@@ -47,6 +122,7 @@ vllm_config.scheduler_config.max_model_len = 8192
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"chunk_size"
,
[
8192
])
def
test_fused_moe
(
def
test_fused_moe
(
m
:
int
,
m
:
int
,
n
:
int
,
n
:
int
,
...
@@ -56,7 +132,21 @@ def test_fused_moe(
...
@@ -56,7 +132,21 @@ def test_fused_moe(
ep_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
padding
:
bool
,
chunk_size
:
int
,
monkeypatch
,
):
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
str
(
chunk_size
))
#
# Setup test data
#
#
# Setup test data
#
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -76,38 +166,70 @@ def test_fused_moe(
...
@@ -76,38 +166,70 @@ def test_fused_moe(
else
:
else
:
e_map
=
None
e_map
=
None
with
set_current_vllm_config
(
vllm_config
):
#
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
# Setup test functions
iterative_output
=
iterative_moe
(
a
,
#
w1
,
w2
,
m_fused_moe_fn
=
modular_triton_fused_moe
(
use_fp8_w8a8
=
False
,
score
,
use_int8_w8a8
=
False
,
topk
,
use_int8_w8a16
=
False
,
global_num_experts
=
e
,
use_int4_w4a16
=
False
,
expert_map
=
e_map
,
per_act_token_quant
=
False
,
renormalize
=
False
)
block_shape
=
None
)
def
m_fused_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
return
m_fused_moe_fn
(
a
,
w1
,
w2
,
topk_weights
,
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
fused_moe_fn
=
functools
.
partial
(
fused_moe
,
renormalize
=
False
)
#
# Run tests
#
runner
=
functools
.
partial
(
run_moe_test
,
a
=
a
,
w1
=
w1
,
w2
=
w2
,
score
=
score
,
topk
=
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
padding
=
padding
,
)
# Pad the weight if moe padding is enabled
# Note: for now use_compile will error out if the problem size is
if
padding
:
# large enough to trigger chunking. I'm leaving the flag and
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
# setup code in case we are able to revisit this later.
torch
.
cuda
.
empty_cache
()
use_compile
=
False
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
use_cudagraph
=
(
n
>=
1024
and
k
>=
1024
w1
,
and
current_platform
.
is_cuda_alike
())
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
with
set_current_vllm_config
(
vllm_config
):
torch
.
testing
.
assert_close
(
iterative_output
,
baseline_output
=
runner
(
torch_moe
,
iterative_moe
)
torch_output
,
runner
(
baseline_output
,
atol
=
2e-2
,
fused_moe_fn
,
rtol
=
0
)
use_compile
=
use_compile
,
use_cudagraph
=
use_cudagraph
)
runner
(
baseline_output
,
m_fused_moe
,
use_compile
=
use_compile
,
use_cudagraph
=
use_cudagraph
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
32
,
222
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
32
,
222
])
...
@@ -217,7 +339,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -217,7 +339,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
expert_map
=
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
...
@@ -243,46 +370,59 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
...
@@ -243,46 +370,59 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if
dtype
==
torch
.
float32
:
if
dtype
==
torch
.
float32
:
pytest
.
skip
(
"AITER ROCm test skip for float32"
)
pytest
.
skip
(
"AITER ROCm test skip for float32"
)
monkeypatch
.
setenv
(
'RANK'
,
"0"
)
monkeypatch
.
setenv
(
'LOCAL_RANK'
,
"0"
)
monkeypatch
.
setenv
(
'WORLD_SIZE'
,
"1"
)
monkeypatch
.
setenv
(
'MASTER_ADDR'
,
'localhost'
)
monkeypatch
.
setenv
(
'MASTER_PORT'
,
'12345'
)
init_distributed_environment
()
# Instantiate our and huggingface's MoE blocks
# Instantiate our and huggingface's MoE blocks
config
=
MixtralConfig
()
vllm_config
.
compilation_config
.
static_forward_context
=
dict
()
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
with
(
set_current_vllm_config
(
vllm_config
),
vllm_moe
=
MixtralMoE
(
set_forward_context
(
None
,
vllm_config
)):
num_experts
=
config
.
num_local_experts
,
config
=
MixtralConfig
()
top_k
=
config
.
num_experts_per_tok
,
hf_moe
=
MixtralSparseMoeBlock
(
config
).
to
(
dtype
).
to
(
"cuda"
)
hidden_size
=
config
.
hidden_size
,
vllm_moe
=
MixtralMoE
(
intermediate_size
=
config
.
intermediate_size
,
num_experts
=
config
.
num_local_experts
,
params_dtype
=
dtype
,
top_k
=
config
.
num_experts_per_tok
,
tp_size
=
1
,
hidden_size
=
config
.
hidden_size
,
dp_size
=
1
,
intermediate_size
=
config
.
intermediate_size
,
).
cuda
()
params_dtype
=
dtype
,
tp_size
=
1
,
# Load the weights
dp_size
=
1
,
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
).
cuda
()
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
# Load the weights
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
for
i
in
range
(
config
.
num_local_experts
):
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
hf_inputs
=
torch
.
randn
((
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
(
(
1
,
64
,
config
.
hidden_size
)).
to
(
dtype
).
to
(
"cuda"
)
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
# Pad the weight if moe padding is enabled
if
padding
:
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
requires_grad
=
False
)
0
:
-
128
],
torch
.
cuda
.
empty_cache
()
requires_grad
=
False
)
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
requires_grad
=
False
)
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
torch
.
cuda
.
empty_cache
()
0
:
-
128
],
requires_grad
=
False
)
# Run forward passes for both MoE blocks
torch
.
cuda
.
empty_cache
()
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
mixtral_moe_tol
=
{
mixtral_moe_tol
=
{
torch
.
float32
:
1e-3
,
torch
.
float32
:
1e-3
,
...
@@ -525,7 +665,12 @@ def test_fused_marlin_moe(
...
@@ -525,7 +665,12 @@ def test_fused_marlin_moe(
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
e_map
)
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
expert_map
=
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
a
,
...
...
tests/kernels/moe/test_moe_align_block_size.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size_triton
)
@pytest.mark.parametrize(
    "block_size,num_tokens,topk,num_experts",
    list(
        itertools.product(
            [32, 64, 128, 256],  # block_size
            [1, 3, 7, 16, 256, 2256, 4096],  # num_tokens
            [1, 4, 16, 64],  # topk
            [64, 160, 256, 257, 260, 264],  # num_experts
        )),
)
def test_moe_align_block_size_compare_implementations(block_size, num_tokens,
                                                      topk, num_experts):
    """Check that the CUDA and Triton ``moe_align_block_size`` paths agree.

    Both implementations are run on the same randomly generated ``topk_ids``
    and their ``expert_ids`` and ``num_tokens_post_pad`` outputs are required
    to match exactly. The ``sorted_ids`` buffers are filled by both calls but
    are not compared by this test.

    Args:
        block_size: Block size passed to both implementations.
        num_tokens: Number of rows in ``topk_ids``.
        topk: Number of expert ids drawn per row.
        num_experts: Size of the id range the expert ids are drawn from.
    """
    # Each row is the first `topk` entries of a random permutation, so the
    # expert ids within one row are distinct.
    topk_ids = torch.stack([
        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
        for _ in range(num_tokens)
    ])
    device = topk_ids.device
    num_ids = topk_ids.numel()

    # Upper bound used to size the output buffers: every expert may need up
    # to `block_size - 1` padding slots.
    max_num_tokens_padded = num_ids + num_experts * (block_size - 1)
    max_num_m_blocks = max_num_tokens_padded // block_size

    # Output buffers for the CUDA op; sorted ids are pre-filled with
    # `topk_ids.numel()`.
    sorted_ids_cuda = torch.full((max_num_tokens_padded, ),
                                 num_ids,
                                 dtype=torch.int32,
                                 device=device)
    expert_ids_cuda = torch.zeros((max_num_m_blocks, ),
                                  dtype=torch.int32,
                                  device=device)
    num_tokens_post_pad_cuda = torch.empty((1, ),
                                           dtype=torch.int32,
                                           device=device)

    # Identically shaped and initialised buffers for the Triton path.
    sorted_ids_triton = torch.full_like(sorted_ids_cuda, num_ids)
    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)

    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
    )

    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )

    # The outputs are int32, so compare exactly: torch.allclose's default
    # relative tolerance would accept small differences on large values.
    assert torch.equal(expert_ids_cuda, expert_ids_triton), (
        f"Expert IDs mismatch for block_size={block_size}, "
        f"num_tokens={num_tokens}, topk={topk}\n"
        f"CUDA expert_ids: {expert_ids_cuda}\n"
        f"Triton expert_ids: {expert_ids_triton}")

    assert torch.equal(
        num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
            f"Num tokens post pad mismatch for block_size={block_size}, "
            f"num_tokens={num_tokens}, topk={topk}\n"
            f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
            f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}")
if __name__ == "__main__":
    # Run this file's tests when executed directly and propagate pytest's
    # exit code, so a failing run does not exit with status 0.
    raise SystemExit(pytest.main([__file__]))
Prev
1
…
12
13
14
15
16
17
18
19
20
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment