Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e9ff2e7
Commit
1e9ff2e7
authored
Apr 13, 2026
by
chenhw5
Browse files
fix deepgemm accuracy bug.
parent
a8d6ba1e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
74 additions
and
47 deletions
+74
-47
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
+1
-1
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+22
-20
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+24
-23
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+24
-0
vllm/utils/import_utils.py
vllm/utils/import_utils.py
+2
-2
No files found.
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
View file @
1e9ff2e7
...
@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
...
@@ -244,7 +244,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dt
and
current_platform
.
has_device_capability
(
100
)
and
current_platform
.
has_device_capability
(
100
)
and
scale_fmt
==
DeepGemmQuantScaleFMT
.
UE8M0
and
scale_fmt
==
DeepGemmQuantScaleFMT
.
UE8M0
):
):
from
deep
_
gemm
import
transform_sf_into_required_layout
from
deepgemm
import
transform_sf_into_required_layout
_q
,
_s
=
ref_with_scale_fmt
(
_q
,
_s
=
ref_with_scale_fmt
(
E
,
E
,
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
1e9ff2e7
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
...
@@ -41,7 +41,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deep
_
gemm
import
m_grouped_w8a8_gemm_nt_masked
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
else
:
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
...
@@ -642,26 +642,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -642,26 +642,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m
,
expected_m
,
)
)
elif
self
.
quant_config
.
use_int8_w8a8
:
elif
self
.
quant_config
.
use_int8_w8a8
:
# m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
# (w1, self.w1_scale),
# workspace1,
# expert_num_tokens,
# expected_m,
# )
# assert expert_num_tokens is not None
# a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
# (w2, self.w2_scale),
# output,
# expert_num_tokens,
# expected_m)
moe_grouped_gemm
(
a1q
,
w1
,
a1q_scale
,
self
.
w1_scale
,
expert_num_tokens
,
workspace1
)
m_grouped_w8a8_gemm_nt_masked
((
a1q
,
a1q_scale
),
act_out
=
self
.
act_fn
(
workspace1
)
(
w1
,
self
.
w1_scale
),
a2q
,
a2q_scale
=
per_token_quant_int8
(
act_out
)
workspace1
,
moe_grouped_gemm
(
a2q
,
w2
,
a2q_scale
,
self
.
w2_scale
,
expert_num_tokens
,
output
)
expert_num_tokens
,
expected_m
,
)
assert
expert_num_tokens
is
not
None
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w8a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
# moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1)
# act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
1e9ff2e7
...
@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
...
@@ -39,7 +39,7 @@ from vllm.utils.import_utils import has_deep_gemm
from
lightop
import
fuse_silu_mul_quant
from
lightop
import
fuse_silu_mul_quant
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deep
_
gemm
import
m_grouped_i8_gemm_nt_contiguous
from
deepgemm
import
m_grouped_i8_gemm_nt_contiguous
else
:
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
as
m_grouped_i8_gemm_nt_contiguous
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
as
m_grouped_i8_gemm_nt_contiguous
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
1e9ff2e7
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported
,
FusedMoEConfig
)
FusedMoeWeightScaleSupported
,
FusedMoEConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
,
weight8bit_nt_kpack2_marlin1
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoE
,
...
@@ -370,28 +370,29 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -370,28 +370,29 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
not
self
.
use_deepep
:
w1_marlin_list
=
[]
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
else
:
#w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
del
w1_marlin_list
w2_marlin_list
=
[]
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
#w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin_in
=
weight8bit_nt_kpack2_marlin1
(
layer
.
w2_weight
[
ii
])
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
1e9ff2e7
...
@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
...
@@ -43,6 +43,30 @@ def w8a8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
return
w8a8_w
def
weight8bit_nt_kpack2_marlin1
(
weight
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
k_tile1
=
4
,
n_tile
=
16
,
n_tile1
=
16
):
assert
weight
.
element_size
()
==
1
,
"weight 必须是 8 bit 类型"
if
weight
.
dim
()
==
2
:
size_n
,
size_k
=
weight
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
q
=
weight
.
reshape
((
size_n
//
(
n_tile
*
n_tile1
),
n_tile1
,
n_tile
,
size_k
//
(
k_tile
*
k_tile1
),
k_tile1
,
k_tile
))
# q = q.permute((0, 2, 1, 3)).contiguous()
q
=
q
.
permute
((
0
,
3
,
1
,
4
,
2
,
5
)).
contiguous
()
q
=
q
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
elif
weight
.
dim
()
==
3
:
E
,
size_n
,
size_k
=
weight
.
shape
assert
size_n
%
n_tile
==
0
and
size_k
%
k_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
q
=
weight
.
reshape
((
E
,
size_n
//
(
n_tile
*
n_tile1
),
n_tile1
,
n_tile
,
size_k
//
(
k_tile
*
k_tile1
),
k_tile1
,
k_tile
))
q
=
q
.
permute
((
0
,
1
,
4
,
2
,
5
,
3
,
6
)).
contiguous
()
q
=
q
.
reshape
((
E
,
size_n
//
k_tile
,
size_k
*
k_tile
))
return
q
def
sparse_cutlass_supported
()
->
bool
:
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
...
...
vllm/utils/import_utils.py
View file @
1e9ff2e7
...
@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
...
@@ -413,8 +413,8 @@ def has_deep_ep() -> bool:
def
has_deep_gemm
()
->
bool
:
def
has_deep_gemm
()
->
bool
:
"""Whether the optional `deep
_
gemm` package is available."""
"""Whether the optional `deepgemm` package is available."""
return
_has_module
(
"deep
_
gemm"
)
return
_has_module
(
"deepgemm"
)
def
has_triton_kernels
()
->
bool
:
def
has_triton_kernels
()
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment