Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d435d1cd
Commit
d435d1cd
authored
Feb 06, 2026
by
jujl1
Browse files
feat:w4a8Linear调用apply_int8_linear,以支持blaslt
parent
4ff0a865
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
76 deletions
+11
-76
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+11
-76
No files found.
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
d435d1cd
...
...
@@ -5,6 +5,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
apply_int8_linear
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
...
...
@@ -111,6 +112,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
...
...
@@ -158,81 +161,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
):
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_SILU_MUL_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_q
,
x_scale
=
silu_quant_args
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_q
,
x_scale
=
silu_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
((
m
+
3
)
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
//
8
)
*
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
return
apply_int8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
bias
=
bias
,
w8a8_strategy
=
self
.
w8a8_strategy
,
input_quant_args
=
input_quant_args
,
silu_quant_args
=
silu_quant_args
)
class
SlimQuantW4A8Int8MoEMethod
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment