Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6efaf21a
Commit
6efaf21a
authored
Apr 14, 2026
by
chenhw5
Committed by
zhangzbb
Apr 16, 2026
Browse files
[BUGFIX]修复deepgemm算子导致的GLM5 W8A8精度问题。
parent
fbfe20c6
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
38 additions
and
8 deletions
+38
-8
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
+1
-1
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+7
-1
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+3
-3
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+24
-0
vllm/utils/import_utils.py
vllm/utils/import_utils.py
+2
-2
No files found.
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
View file @
6efaf21a
...
@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
...
@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
and
current_platform
.
has_device_capability
(
100
)
and
current_platform
.
has_device_capability
(
100
)
and
scale_fmt
==
DeepGemmQuantScaleFMT
.
UE8M0
and
scale_fmt
==
DeepGemmQuantScaleFMT
.
UE8M0
):
):
from
deep
_
gemm
import
transform_sf_into_required_layout
from
deepgemm
import
transform_sf_into_required_layout
_q
,
_s
=
ref_with_scale_fmt
(
_q
,
_s
=
ref_with_scale_fmt
(
E
,
E
,
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
6efaf21a
...
@@ -36,7 +36,7 @@ from vllm.utils.import_utils import has_deep_gemm
...
@@ -36,7 +36,7 @@ from vllm.utils.import_utils import has_deep_gemm
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deep
_
gemm
import
m_grouped_w8a8_gemm_nt_masked
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
else
:
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
...
@@ -481,5 +481,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -481,5 +481,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output
,
output
,
expert_num_tokens
,
expert_num_tokens
,
expected_m
)
expected_m
)
# moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1)
# act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
6efaf21a
...
@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
...
@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
from
lightop
import
fuse_silu_mul_quant
from
lightop
import
fuse_silu_mul_quant
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deep
_
gemm
import
m_grouped_i8_gemm_nt_contiguous
from
deepgemm
import
m_grouped_i8_gemm_nt_contiguous
else
:
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
as
m_grouped_i8_gemm_nt_contiguous
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
as
m_grouped_i8_gemm_nt_contiguous
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
6efaf21a
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported
,
FusedMoEConfig
)
FusedMoeWeightScaleSupported
,
FusedMoEConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
,
weight8bit_nt_kpack2_marlin1
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoE
,
...
@@ -375,7 +375,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -375,7 +375,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
else
:
w1_marlin_in
=
w
8a8
_nt_kpack2_marlin
_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
w
eight8bit
_nt_kpack2_marlin
1
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
...
@@ -385,7 +385,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -385,7 +385,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
else
:
w2_marlin_in
=
w
8a8
_nt_kpack2_marlin
_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_in
=
w
eight8bit
_nt_kpack2_marlin
1
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
6efaf21a
...
@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
...
@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
return
w8a8_w
def
weight8bit_nt_kpack2_marlin1
(
weight
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
k_tile1
=
4
,
n_tile
=
16
,
n_tile1
=
16
):
assert
weight
.
element_size
()
==
1
,
"weight 必须是 8 bit 类型"
if
weight
.
dim
()
==
2
:
size_n
,
size_k
=
weight
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
q
=
weight
.
reshape
((
size_n
//
(
n_tile
*
n_tile1
),
n_tile1
,
n_tile
,
size_k
//
(
k_tile
*
k_tile1
),
k_tile1
,
k_tile
))
# q = q.permute((0, 2, 1, 3)).contiguous()
q
=
q
.
permute
((
0
,
3
,
1
,
4
,
2
,
5
)).
contiguous
()
q
=
q
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
elif
weight
.
dim
()
==
3
:
E
,
size_n
,
size_k
=
weight
.
shape
assert
size_n
%
n_tile
==
0
and
size_k
%
k_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
q
=
weight
.
reshape
((
E
,
size_n
//
(
n_tile
*
n_tile1
),
n_tile1
,
n_tile
,
size_k
//
(
k_tile
*
k_tile1
),
k_tile1
,
k_tile
))
q
=
q
.
permute
((
0
,
1
,
4
,
2
,
5
,
3
,
6
)).
contiguous
()
q
=
q
.
reshape
((
E
,
size_n
//
k_tile
,
size_k
*
k_tile
))
return
q
def
sparse_cutlass_supported
()
->
bool
:
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
...
...
vllm/utils/import_utils.py
View file @
6efaf21a
...
@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
...
@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
def
has_deep_gemm
()
->
bool
:
def
has_deep_gemm
()
->
bool
:
"""Whether the optional `deep
_
gemm` package is available."""
"""Whether the optional `deepgemm` package is available."""
return
_has_module
(
"deep
_
gemm"
)
return
_has_module
(
"deepgemm"
)
def
has_triton_kernels
()
->
bool
:
def
has_triton_kernels
()
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment