Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16d49763
Commit
16d49763
authored
Dec 25, 2025
by
gaoqiong
Browse files
修复w8a8 triton config 择优位运算可能引发torch compile 编译错误,修复smquant w8a8 权重后处理位置
parent
7d5faa43
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
29 deletions
+33
-29
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+3
-26
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+27
-0
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-3
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
16d49763
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
cutlass_fp4_supported
)
cutlass_fp4_supported
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
W8a8GetCacheJSON
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -616,33 +616,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -616,33 +616,10 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
self
.
quantization_config
=
quantization_config
self
.
quantization_config
=
quantization_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
[
n
,
k
]
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
([
n
,
k
])
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
self
.
tritonsingleton
.
gen_model_json
()
layer
.
scheme
.
process_weights_after_loading
(
layer
)
layer
.
scheme
.
process_weights_after_loading
(
layer
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
16d49763
...
@@ -18,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
...
@@ -18,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -29,6 +31,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -29,6 +31,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric
:
bool
):
input_symmetric
:
bool
):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'1'
))
self
.
input_symmetric
=
input_symmetric
self
.
input_symmetric
=
input_symmetric
...
@@ -108,6 +111,30 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -108,6 +111,30 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# Checkpoints are serialized in compressed-tensors format, which is
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
n
=
layer
.
weight
.
shape
[
0
]
k
=
layer
.
weight
.
shape
[
1
]
if
self
.
w8a8_strategy
==
1
:
if
[
n
,
k
]
not
in
self
.
tritonsingleton
.
weight_shapes
:
self
.
tritonsingleton
.
weight_shapes
.
append
([
n
,
k
])
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
self
.
tritonsingleton
.
triton_json_dict
.
update
(
configs_dict
)
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
layer
.
weight
.
data
=
_weight
self
.
tritonsingleton
.
gen_model_json
()
self
.
kernel
.
process_weights_after_loading
(
layer
)
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
16d49763
...
@@ -455,11 +455,11 @@ def apply_int8_linear(
...
@@ -455,11 +455,11 @@ def apply_int8_linear(
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
if
m
<=
16
:
m_
=
m
m_
=
m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif
m
<=
64
:
elif
m
<=
64
:
m_
=
(
m
+
3
)
&
-
4
#取值到最近的4的倍数
m_
=
(
m
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
elif
m
<=
160
:
m_
=
(
m
+
7
)
&
-
8
m_
=
(
m
//
8
)
*
8
elif
m
<
200
:
#256
elif
m
<
200
:
#256
m_
=
160
m_
=
160
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment