Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a75c6bc
"platforms/vscode:/vscode.git/clone" did not exist on "6f24ee8a965137f83409d0f57b4a0f96eb1f2fb4"
Commit
2a75c6bc
authored
Jan 15, 2026
by
zhuwenwen
Browse files
fix tests of kernels
parent
3dd7fd64
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
325 additions
and
325 deletions
+325
-325
tests/kernels/core/untest_mrope.py
tests/kernels/core/untest_mrope.py
+23
-23
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+1
-1
tests/kernels/moe/untest_block_int8.py
tests/kernels/moe/untest_block_int8.py
+0
-0
tests/kernels/moe/untest_moe.py
tests/kernels/moe/untest_moe.py
+301
-301
No files found.
tests/kernels/core/test_mrope.py
→
tests/kernels/core/
un
test_mrope.py
View file @
2a75c6bc
...
...
@@ -49,27 +49,27 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION
=
Version
(
TRANSFORMERS_VERSION
).
base_version
MODELS_TO_TEST
=
[
MRoPETestInfo
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"zai-org/GLM-4.1V-9B-Thinking"
)
)
,
MRoPETestInfo
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-VL-7B-Instruct"
)
)
,
MRoPETestInfo
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2-VL-72B-Instruct"
)
)
,
# MRoPETestInfo(model_name=
os.path.join(
"Qwen/Qwen2.5-VL-72B-Instruct")
)
,
MRoPETestInfo
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-4B-Instruct"
)
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
MRoPETestInfo
(
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-30B-A3B-Instruct"
)
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
MODELS_TO_TEST
=
[
MRoPETestInfo
(
model_name
=
"zai-org/GLM-4.1V-9B-Thinking"
),
#
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
#
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
# MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
#
MRoPETestInfo(
#
model_name="Qwen/Qwen3-VL-4B-Instruct",
#
marks=[
#
pytest.mark.skipif(
#
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
#
reason="Qwen3-VL only available after Transformers v4.57",
#
)
#
]),
#
MRoPETestInfo(
#
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
#
marks=[
#
pytest.mark.skipif(
#
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
#
reason="Qwen3-VL only available after Transformers v4.57",
#
)
#
]),
]
num_tokens_list
=
[
11
,
8192
]
...
...
@@ -78,7 +78,7 @@ num_tokens_list = [11, 8192]
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
"model_info, model_name"
,
[
pytest
.
param
(
test_config
,
test_config
.
model_name
,
marks
=
test_config
.
marks
)
pytest
.
param
(
test_config
,
os
.
path
.
join
(
models_path_prefix
,
test_config
.
model_name
)
,
marks
=
test_config
.
marks
)
for
test_config
in
MODELS_TO_TEST
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
...
...
@@ -90,7 +90,7 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
atol
=
model_info
.
atol
rtol
=
model_info
.
rtol
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
config
=
config
.
get_text_config
()
# get the model config
...
...
tests/kernels/moe/test_batched_moe.py
View file @
2a75c6bc
...
...
@@ -90,7 +90,7 @@ class BatchedMMTensors:
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
,
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
224
,
512
])
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
512
])
# 224
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
]
if
not
current_platform
.
is_rocm
()
else
[
torch
.
bfloat16
])
...
...
tests/kernels/moe/test_block_int8.py
→
tests/kernels/moe/
un
test_block_int8.py
View file @
2a75c6bc
File moved
tests/kernels/moe/test_moe.py
→
tests/kernels/moe/
un
test_moe.py
View file @
2a75c6bc
...
...
@@ -528,305 +528,305 @@ def marlin_moe_generate_valid_test_cases():
return
cases
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
parametrize
((
"m, n, k, e, topk, ep_size, dtype, group_size,"
"act_order, quant_type, is_k_full"
),
marlin_moe_generate_valid_test_cases
())
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_fused_marlin_moe
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
act_order
:
bool
,
quant_type
:
ScalarType
,
is_k_full
:
bool
,
):
torch
.
cuda
.
manual_seed
(
0
)
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
#
@pytest.mark.flaky(reruns=2)
#
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
#
"act_order, quant_type, is_k_full"),
#
marlin_moe_generate_valid_test_cases())
#
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
#
def test_fused_marlin_moe(
#
m: int,
#
n: int,
#
k: int,
#
e: int,
#
topk: int,
#
ep_size: int,
#
dtype: torch.dtype,
#
group_size: int,
#
act_order: bool,
#
quant_type: ScalarType,
#
is_k_full: bool,
#
):
#
torch.cuda.manual_seed(0)
#
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
20
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
20
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randperm
(
e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[:
local_e
]
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1
=
w1
[
e_ids
]
w2
=
w2
[
e_ids
]
else
:
e_map
=
None
# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
w_ref1_l
=
[]
qweight1_l
=
[]
scales1_l
=
[]
global_scale1_l
=
[]
zeros1_l
=
[]
g_idx1_l
=
[]
sort_indices1_l
=
[]
# if ep_size > 1:
# local_e = e // ep_size
# e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
# e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
# e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
# w1 = w1[e_ids]
# w2 = w2[e_ids]
# else:
# e_map = None
for
i
in
range
(
w1
.
shape
[
0
]):
if
quant_type
==
scalar_types
.
float4_e2m1f
:
if
group_size
==
16
:
w_ref1
,
qweight1
,
scales1
,
global_scale1
=
\
rand_marlin_weight_nvfp4_like
(
w1
[
i
],
group_size
)
else
:
w_ref1
,
qweight1
,
scales1
=
\
rand_marlin_weight_mxfp4_like
(
w1
[
i
],
group_size
)
global_scale1
=
None
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
if
global_scale1
is
not
None
:
global_scale1_l
.
append
(
global_scale1
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref1
,
qweight1
,
scales1
=
marlin_quant_fp8_torch
(
w1
[
i
],
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
elif
has_zp
:
w_ref1
,
qweight1
,
scales1
,
zeros1
=
awq_marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
zeros1_l
.
append
(
zeros1
)
else
:
test_perm
=
torch
.
randperm
(
k
)
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
\
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
g_idx1_l
.
append
(
g_idx1
)
sort_indices1_l
.
append
(
sort_indices1
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
global_scale1
=
stack_and_dev
(
global_scale1_l
)
if
global_scale1_l
else
None
g_idx1
=
stack_and_dev
(
g_idx1_l
)
if
g_idx1_l
else
None
zeros1
=
stack_and_dev
(
zeros1_l
)
if
zeros1_l
else
None
sort_indices1
=
stack_and_dev
(
sort_indices1_l
)
if
sort_indices1_l
else
None
w_ref2_l
=
[]
qweight2_l
=
[]
scales2_l
=
[]
global_scale2_l
=
[]
zeros2_l
=
[]
g_idx2_l
=
[]
sort_indices2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
if
quant_type
==
scalar_types
.
float4_e2m1f
:
if
group_size
==
16
:
w_ref2
,
qweight2
,
scales2
,
global_scale2
=
\
rand_marlin_weight_nvfp4_like
(
w2
[
i
],
group_size
)
else
:
w_ref2
,
qweight2
,
scales2
=
\
rand_marlin_weight_mxfp4_like
(
w2
[
i
],
group_size
)
global_scale2
=
None
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
if
global_scale2
is
not
None
:
global_scale2_l
.
append
(
global_scale2
)
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
w_ref2
,
qweight2
,
scales2
=
marlin_quant_fp8_torch
(
w2
[
i
],
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
elif
has_zp
:
w_ref2
,
qweight2
,
scales2
,
zeros2
=
awq_marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
zeros2_l
.
append
(
zeros2
)
else
:
test_perm
=
torch
.
randperm
(
n
)
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
\
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
g_idx2_l
.
append
(
g_idx2
)
sort_indices2_l
.
append
(
sort_indices2
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
global_scale2
=
stack_and_dev
(
global_scale2_l
)
if
global_scale2_l
else
None
g_idx2
=
stack_and_dev
(
g_idx2_l
)
if
g_idx2_l
else
None
zeros2
=
stack_and_dev
(
zeros2_l
)
if
zeros2_l
else
None
sort_indices2
=
stack_and_dev
(
sort_indices2_l
)
if
sort_indices2_l
else
None
# w_ref1_l = []
# qweight1_l = []
# scales1_l = []
# global_scale1_l = []
# zeros1_l = []
# g_idx1_l = []
# sort_indices1_l = []
# for i in range(w1.shape[0]):
# if quant_type == scalar_types.float4_e2m1f:
# if group_size == 16:
# w_ref1, qweight1, scales1, global_scale1 = \
# rand_marlin_weight_nvfp4_like(w1[i], group_size)
# else:
# w_ref1, qweight1, scales1 = \
# rand_marlin_weight_mxfp4_like(w1[i], group_size)
# global_scale1 = None
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# if global_scale1 is not None:
# global_scale1_l.append(global_scale1)
# elif quant_type == scalar_types.float8_e4m3fn:
# w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(
# w1[i], group_size)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# elif has_zp:
# w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
# w1[i].transpose(1, 0), quant_type, group_size)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# zeros1_l.append(zeros1)
# else:
# test_perm = torch.randperm(k)
# w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
# marlin_quantize(w1[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# g_idx1_l.append(g_idx1)
# sort_indices1_l.append(sort_indices1)
# w_ref1 = stack_and_dev(w_ref1_l)
# qweight1 = stack_and_dev(qweight1_l).contiguous()
# scales1 = stack_and_dev(scales1_l)
# global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
# g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
# zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
# sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
# w_ref2_l = []
# qweight2_l = []
# scales2_l = []
# global_scale2_l = []
# zeros2_l = []
# g_idx2_l = []
# sort_indices2_l = []
# for i in range(w2.shape[0]):
# if quant_type == scalar_types.float4_e2m1f:
# if group_size == 16:
# w_ref2, qweight2, scales2, global_scale2 = \
# rand_marlin_weight_nvfp4_like(w2[i], group_size)
# else:
# w_ref2, qweight2, scales2 = \
# rand_marlin_weight_mxfp4_like(w2[i], group_size)
# global_scale2 = None
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# if global_scale2 is not None:
# global_scale2_l.append(global_scale2)
# elif quant_type == scalar_types.float8_e4m3fn:
# w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(
# w2[i], group_size)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# elif has_zp:
# w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
# w2[i].transpose(1, 0), quant_type, group_size)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# zeros2_l.append(zeros2)
# else:
# test_perm = torch.randperm(n)
# w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
# marlin_quantize(w2[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# g_idx2_l.append(g_idx2)
# sort_indices2_l.append(sort_indices2)
# w_ref2 = stack_and_dev(w_ref2_l)
# qweight2 = stack_and_dev(qweight2_l).contiguous()
# scales2 = stack_and_dev(scales2_l)
# global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
# g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
# zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
# sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
#
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
#
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
expert_map
=
e_map
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
None
,
None
,
scales1
,
scales2
,
score
,
topk_weights
,
topk_ids
,
global_num_experts
=
e
,
expert_map
=
e_map
,
global_scale1
=
global_scale1
,
global_scale2
=
global_scale2
,
g_idx1
=
g_idx1
,
g_idx2
=
g_idx2
,
sort_indices1
=
sort_indices1
,
sort_indices2
=
sort_indices2
,
w1_zeros
=
zeros1
,
w2_zeros
=
zeros2
,
quant_type_id
=
quant_type
.
id
,
is_k_full
=
is_k_full
)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
5e-2
,
rtol
=
0
)
@
pytest
.
mark
.
flaky
(
reruns
=
2
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
256
])
def
test_fused_marlin_moe_with_bias
(
m
):
torch
.
cuda
.
manual_seed
(
0
)
e
,
topk
=
32
,
4
n
,
k
=
2048
,
2048
group_size
=
128
act_order
=
False
is_k_full
=
True
quant_type
=
scalar_types
.
uint4b8
dtype
=
torch
.
half
# with set_current_vllm_config(vllm_config):
# torch_output = torch_moe(a,
# w_ref1,
# w_ref2,
# score,
# topk,
# expert_map=e_map)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
b_bias1
=
torch
.
randn
((
e
,
2
*
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
b_bias2
=
torch
.
randn
((
e
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
b_bias1_l
=
[]
w_ref1_l
=
[]
qweight1_l
=
[]
scales1_l
=
[]
g_idx1_l
=
[]
sort_indices1_l
=
[]
for
i
in
range
(
w1
.
shape
[
0
]):
test_perm
=
torch
.
randperm
(
k
)
w_ref1
,
qweight1
,
scales1
,
g_idx1
,
sort_indices1
,
_
=
\
marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref1_l
.
append
(
w_ref1
.
T
)
qweight1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
g_idx1_l
.
append
(
g_idx1
)
sort_indices1_l
.
append
(
sort_indices1
)
b_bias1_l
.
append
(
marlin_permute_bias
(
b_bias1
[
i
]))
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweight1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
global_scale1
=
None
g_idx1
=
stack_and_dev
(
g_idx1_l
)
if
g_idx1_l
else
None
zeros1
=
None
sort_indices1
=
stack_and_dev
(
sort_indices1_l
)
if
sort_indices1_l
else
None
marlin_bias1
=
stack_and_dev
(
b_bias1_l
)
if
b_bias1_l
else
None
b_bias2_l
=
[]
w_ref2_l
=
[]
qweight2_l
=
[]
scales2_l
=
[]
g_idx2_l
=
[]
sort_indices2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
test_perm
=
torch
.
randperm
(
n
)
w_ref2
,
qweight2
,
scales2
,
g_idx2
,
sort_indices2
,
_
=
\
marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
,
act_order
,
test_perm
)
w_ref2_l
.
append
(
w_ref2
.
T
)
qweight2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
g_idx2_l
.
append
(
g_idx2
)
sort_indices2_l
.
append
(
sort_indices2
)
b_bias2_l
.
append
(
marlin_permute_bias
(
b_bias2
[
i
]))
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweight2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
global_scale2
=
None
g_idx2
=
stack_and_dev
(
g_idx2_l
)
if
g_idx2_l
else
None
zeros2
=
None
sort_indices2
=
stack_and_dev
(
sort_indices2_l
)
if
sort_indices2_l
else
None
marlin_bias2
=
stack_and_dev
(
b_bias2_l
)
if
b_bias2_l
else
None
# marlin_output = torch.ops.vllm.fused_marlin_moe(
# a,
# qweight1,
# qweight2,
# None,
# None,
# scales1,
# scales2,
# score,
# topk_weights,
# topk_ids,
# global_num_experts=e,
# expert_map=e_map,
# global_scale1=global_scale1,
# global_scale2=global_scale2,
# g_idx1=g_idx1,
# g_idx2=g_idx2,
# sort_indices1=sort_indices1,
# sort_indices2=sort_indices2,
# w1_zeros=zeros1,
# w2_zeros=zeros2,
# quant_type_id=quant_type.id,
# is_k_full=is_k_full)
# torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
# @pytest.mark.flaky(reruns=2)
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.parametrize("m", [1, 256])
# def test_fused_marlin_moe_with_bias(m):
# torch.cuda.manual_seed(0)
# e, topk = 32, 4
# n, k = 2048, 2048
# group_size = 128
# act_order = False
# is_k_full = True
# quant_type = scalar_types.uint4b8
# dtype = torch.half
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
# b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
# b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
# b_bias1_l = []
# w_ref1_l = []
# qweight1_l = []
# scales1_l = []
# g_idx1_l = []
# sort_indices1_l = []
# for i in range(w1.shape[0]):
# test_perm = torch.randperm(k)
# w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
# marlin_quantize(w1[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref1_l.append(w_ref1.T)
# qweight1_l.append(qweight1)
# scales1_l.append(scales1)
# g_idx1_l.append(g_idx1)
# sort_indices1_l.append(sort_indices1)
# b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
# w_ref1 = stack_and_dev(w_ref1_l)
# qweight1 = stack_and_dev(qweight1_l).contiguous()
# scales1 = stack_and_dev(scales1_l)
# global_scale1 = None
# g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
# zeros1 = None
# sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
# marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
# b_bias2_l = []
# w_ref2_l = []
# qweight2_l = []
# scales2_l = []
# g_idx2_l = []
# sort_indices2_l = []
# for i in range(w2.shape[0]):
# test_perm = torch.randperm(n)
# w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
# marlin_quantize(w2[i].transpose(1, 0), quant_type,
# group_size, act_order, test_perm)
# w_ref2_l.append(w_ref2.T)
# qweight2_l.append(qweight2)
# scales2_l.append(scales2)
# g_idx2_l.append(g_idx2)
# sort_indices2_l.append(sort_indices2)
# b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
# w_ref2 = stack_and_dev(w_ref2_l)
# qweight2 = stack_and_dev(qweight2_l).contiguous()
# scales2 = stack_and_dev(scales2_l)
# global_scale2 = None
# g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
# zeros2 = None
# sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
# marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
topk_weights
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
Fals
e
)
#
score = torch.randn((m, e), device="cuda", dtype=dtyp
e)
with
set_current_vllm_config
(
vllm_config
):
torch_output
=
torch_moe
(
a
,
w_ref1
,
w_ref2
,
score
,
topk
,
b_bias1
,
b_bias2
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
marlin_bias1
,
marlin_bias2
,
scales1
,
scales2
,
score
,
topk_weights
,
topk_ids
,
global_num_experts
=
e
,
expert_map
=
None
,
global_scale1
=
global_scale1
,
global_scale2
=
global_scale2
,
g_idx1
=
g_idx1
,
g_idx2
=
g_idx2
,
sort_indices1
=
sort_indices1
,
sort_indices2
=
sort_indices2
,
w1_zeros
=
zeros1
,
w2_zeros
=
zeros2
,
quant_type_id
=
quant_type
.
id
,
is_k_full
=
is_k_full
)
# topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
5e-2
,
rtol
=
0
)
# with set_current_vllm_config(vllm_config):
# torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
# b_bias2)
# marlin_output = torch.ops.vllm.fused_marlin_moe(
# a,
# qweight1,
# qweight2,
# marlin_bias1,
# marlin_bias2,
# scales1,
# scales2,
# score,
# topk_weights,
# topk_ids,
# global_num_experts=e,
# expert_map=None,
# global_scale1=global_scale1,
# global_scale2=global_scale2,
# g_idx1=g_idx1,
# g_idx2=g_idx2,
# sort_indices1=sort_indices1,
# sort_indices2=sort_indices2,
# w1_zeros=zeros1,
# w2_zeros=zeros2,
# quant_type_id=quant_type.id,
# is_k_full=is_k_full)
# torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def
test_moe_align_block_size_opcheck
():
...
...
@@ -855,19 +855,19 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad
))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_moe_sum
(
m
:
int
,
topk
:
int
,
k
:
int
,
dtype
:
torch
.
dtype
):
input
=
torch
.
randn
((
m
,
topk
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
actual
=
torch
.
empty
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
#
@pytest.mark.parametrize("m", [1, 33, 64, 222])
#
@pytest.mark.parametrize("topk", TOP_KS)
#
@pytest.mark.parametrize("k", [128, 511, 1024])
#
@pytest.mark.parametrize("dtype",
#
[torch.float32, torch.float16, torch.bfloat16])
#
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
#
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
#
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
#
actual = torch.empty((m, k), device="cuda", dtype=dtype)
expected
=
input
.
sum
(
dim
=
1
)
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
actual
)
#
expected = input.sum(dim=1)
#
torch.ops._moe_C.moe_sum(input, actual)
torch
.
testing
.
assert_close
(
actual
,
expected
,
atol
=
2e-2
,
rtol
=
0
)
#
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
opcheck
(
torch
.
ops
.
_moe_C
.
moe_sum
,
(
input
,
actual
))
#
opcheck(torch.ops._moe_C.moe_sum, (input, actual))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment