Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
20852c8f
Unverified
Commit
20852c8f
authored
Nov 19, 2025
by
Li, Jiang
Committed by
GitHub
Nov 19, 2025
Browse files
[CPU] Refactor CPU WNA16 (#28826)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
40b6b38f
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
626 additions
and
1 deletion
+626
-1
vllm/model_executor/layers/quantization/cpu_wna16.py
vllm/model_executor/layers/quantization/cpu_wna16.py
+625
-0
vllm/model_executor/layers/quantization/ipex_quant.py
vllm/model_executor/layers/quantization/ipex_quant.py
+1
-1
No files found.
vllm/model_executor/layers/quantization/cpu_wna16.py
0 → 100644
View file @
20852c8f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
torch
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
vllm._custom_ops
import
(
cpu_gemm_wna16
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_repeat_scales_on_all_ranks
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
,
pack_cols
,
unpack_cols
,
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_safetensors_params_metadata
from
vllm.utils.collection_utils
import
is_list_of
logger
=
init_logger
(
__name__
)
class
CPUGPTQConfig
(
QuantizationConfig
):
"""Config class for CPU GPTQ quant"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
dynamic
:
dict
[
str
,
dict
[
str
,
int
|
bool
]],
full_config
:
dict
[
str
,
Any
],
modules_in_block_to_quantize
:
list
[
str
]
|
None
=
None
,
)
->
None
:
super
().
__init__
()
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act
=
False
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
# Format is dict[str, dict] where key is a regex string that can
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
# matching of a module.
# Default to positive match, override base quant config mode, if no
# prefix is used. Value is in dict format of field key and override
# value.
# Negative matching will skip quantization init for this module
# entirely:
# non-quantized inference. More details and quantization examples can be
# found at: https://github.com/ModelCloud/GPTQModel
# Example:
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
# dynamic = {
# #`.*\.` matches the layers_node prefix
# # positive match layer 10-15
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
# # positive match layer 16-21
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
# }
assert
weight_bits
==
4
self
.
dynamic
=
dynamic
self
.
weight_bits
=
weight_bits
self
.
is_sym
=
is_sym
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
lm_head_quantized
=
lm_head_quantized
self
.
full_config
=
full_config
self
.
modules_in_block_to_quantize
=
modules_in_block_to_quantize
or
[]
def
__repr__
(
self
)
->
str
:
return
(
f
"CPUWNA16Config("
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
, "
f
"dynamic=
{
self
.
dynamic
}
, "
f
"modules_in_block_to_quantize=
{
self
.
modules_in_block_to_quantize
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
return
"cpu_gptq"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
-
1
@
classmethod
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"CPUGPTQConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
desc_act
=
cls
.
get_from_keys_or
(
config
,
[
"desc_act"
],
default
=
False
)
dynamic
=
cls
.
get_from_keys_or
(
config
,
[
"dynamic"
],
default
=
{})
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
modules_in_block_to_quantize
=
cls
.
get_from_keys_or
(
config
,
[
"modules_in_block_to_quantize"
],
default
=
None
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
,
lm_head_quantized
,
dynamic
,
config
,
modules_in_block_to_quantize
,
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
QuantizationMethods
|
None
:
quant_method
=
hf_quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
if
current_platform
.
is_cpu
()
and
(
quant_method
==
"gptq"
):
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
CPUGPTQLinearMethod
)
# type: ignore
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
):
if
self
.
modules_in_block_to_quantize
is
not
None
:
self
.
modules_in_block_to_quantize
=
hf_to_vllm_mapper
.
apply_list
(
self
.
modules_in_block_to_quantize
)
def
maybe_update_config
(
self
,
model_name
:
str
,
revision
:
str
|
None
=
None
):
if
self
.
modules_in_block_to_quantize
:
if
is_list_of
(
self
.
modules_in_block_to_quantize
,
list
):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self
.
modules_in_block_to_quantize
=
[
item
for
sublist
in
self
.
modules_in_block_to_quantize
for
item
in
sublist
]
return
unquant_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
metadata
=
get_safetensors_params_metadata
(
model_name
,
revision
=
revision
)
quant_layers
:
set
[
str
]
=
{
param_name
.
rsplit
(
"."
,
1
)[
0
]
for
param_name
,
info
in
metadata
.
items
()
if
(
dtype
:
=
info
.
get
(
"dtype"
,
None
))
and
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype
]
not
in
unquant_dtypes
}
self
.
modules_in_block_to_quantize
=
list
(
quant_layers
)
class
CPUGPTQLinearMethod
(
LinearMethodBase
):
"""Linear method for GPTQ on CPU.
Args:
quant_config: The CPUWNA16 quantization config.
"""
def
__init__
(
self
,
quant_config
:
CPUGPTQConfig
)
->
None
:
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
is_sym
,
"GPTQ asym quant is not supported on CPU"
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
output_size_per_partition
=
sum
(
output_partition_sizes
)
assert
output_size_per_partition
*
self
.
quant_config
.
weight_bits
%
32
==
0
assert
output_size_per_partition
%
32
==
0
assert
input_size_per_partition
%
32
==
0
is_row_parallel
=
input_size
!=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
self
.
quant_config
.
group_size
,
is_row_parallel
):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each rank in TP>1 case.
scales_and_zp_input_dim
=
None
scales_and_zp_size
=
input_size
//
group_size
else
:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim
=
0
scales_and_zp_size
=
input_size_per_partition
//
group_size
# Quantized weights
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
,
)
# Activation order
g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
,
)
set_weight_attrs
(
g_idx
,
{
"ignore_warning"
:
True
},
)
qzeros_args
=
{
"data"
:
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
"weight_loader"
:
weight_loader
,
}
weight_scale_args
=
{
"data"
:
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
"weight_loader"
:
weight_loader
,
}
if
scales_and_zp_input_dim
is
None
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
qzeros
=
PackedColumnParameter
(
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
,
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
qzeros
=
PackedvLLMParameter
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
,
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
torch
.
set_printoptions
(
profile
=
"full"
,
linewidth
=
5000
,
sci_mode
=
False
)
packed_weight
=
layer
.
qweight
.
data
bits
=
self
.
quant_config
.
weight_bits
pack_factor
=
int
(
self
.
quant_config
.
pack_factor
)
p_w_k
,
p_w_n
=
packed_weight
.
size
()
input_size
=
p_w_k
*
pack_factor
output_size
=
p_w_n
isa_hint
=
_get_isa_hint
(
layer
.
scales
.
dtype
)
layer
.
isa_hint
=
isa_hint
layer
.
qzeros
=
None
if
not
self
.
quant_config
.
desc_act
:
layer
.
g_idx
=
None
# convert input dim packed to output dim packed
weight
=
unpack_cols
(
packed_weight
,
bits
,
p_w_k
,
p_w_n
*
pack_factor
).
view
(
p_w_k
,
p_w_n
,
pack_factor
)
weight
=
weight
.
permute
(
0
,
2
,
1
).
reshape
(
input_size
,
output_size
).
contiguous
()
weight
=
pack_cols
(
weight
,
bits
,
input_size
,
output_size
)
# make 16 output channel as a block and transpose to the make
# the block contigous
weight
=
(
weight
.
view
(
input_size
,
-
1
,
16
//
pack_factor
)
.
permute
(
1
,
0
,
2
)
.
reshape
(
-
1
,
input_size
*
16
//
pack_factor
)
.
contiguous
()
)
layer
.
qweight
.
data
=
weight
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
x
=
cpu_gemm_wna16
(
input
=
x
,
q_weight
=
layer
.
qweight
,
scales
=
layer
.
scales
,
zeros
=
layer
.
qzeros
,
g_idx
=
layer
.
g_idx
,
bias
=
bias
,
pack_factor
=
8
,
isa_hint
=
layer
.
isa_hint
,
)
return
x
class
CPUAWQConfig
(
QuantizationConfig
):
"""Config class for CPU AWQ"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
zero_point
:
bool
,
lm_head_quantized
:
bool
,
modules_to_not_convert
:
list
[
str
]
|
None
,
full_config
:
dict
[
str
,
Any
],
)
->
None
:
super
().
__init__
()
assert
weight_bits
==
4
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
zero_point
=
zero_point
self
.
lm_head_quantized
=
lm_head_quantized
self
.
weight_bits
=
weight_bits
self
.
modules_to_not_convert
=
modules_to_not_convert
or
[]
self
.
full_config
=
full_config
def
__repr__
(
self
)
->
str
:
return
(
f
"AWQMarlinConfig("
f
"group_size=
{
self
.
group_size
}
, "
f
"zero_point=
{
self
.
zero_point
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
, "
f
"modules_to_not_convert=
{
self
.
modules_to_not_convert
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
"QuantizationMethods"
:
return
"cpu_awq"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
-
1
@
classmethod
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"CPUAWQConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
modules_to_not_convert
=
cls
.
get_from_keys_or
(
config
,
[
"modules_to_not_convert"
],
None
)
return
cls
(
weight_bits
,
group_size
,
zero_point
,
lm_head_quantized
,
modules_to_not_convert
,
config
,
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
"QuantizationMethods"
]:
quant_method
=
hf_quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
if
current_platform
.
is_cpu
()
and
(
quant_method
==
"awq"
):
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
):
if
is_layer_skipped
(
prefix
,
self
.
modules_to_not_convert
,
self
.
packed_modules_mapping
,
skip_with_substr
=
True
,
):
return
UnquantizedLinearMethod
()
return
CPUAWQLinearMethod
(
self
)
return
None
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
if
self
.
modules_to_not_convert
:
self
.
modules_to_not_convert
=
hf_to_vllm_mapper
.
apply_list
(
self
.
modules_to_not_convert
)
def
maybe_update_config
(
self
,
model_name
:
str
,
revision
:
str
|
None
=
None
):
if
self
.
modules_to_not_convert
:
return
unquant_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
metadata
=
get_safetensors_params_metadata
(
model_name
,
revision
=
revision
)
layers
=
{
param_name
.
rsplit
(
"."
,
1
)[
0
]
for
param_name
in
metadata
}
quant_layers
:
set
[
str
]
=
{
param_name
.
rsplit
(
"."
,
1
)[
0
]
for
param_name
,
info
in
metadata
.
items
()
if
(
dtype
:
=
info
.
get
(
"dtype"
,
None
))
and
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype
]
not
in
unquant_dtypes
}
self
.
modules_to_not_convert
=
list
(
layers
-
quant_layers
)
class
CPUAWQLinearMethod
(
LinearMethodBase
):
"""Linear method for CPU AWQ.
Args:
quant_config: The CPU AWQ quantization config.
"""
def
__init__
(
self
,
quant_config
:
CPUAWQConfig
)
->
None
:
self
.
quant_config
=
quant_config
assert
self
.
quant_config
.
zero_point
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
,
)
num_groups
=
input_size_per_partition
//
group_size
qzeros
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
num_groups
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
,
)
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
num_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"scales"
,
scales
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
torch
.
set_printoptions
(
profile
=
"full"
,
linewidth
=
5000
,
sci_mode
=
False
)
packed_weight
=
layer
.
qweight
.
data
packed_zeros
=
layer
.
qzeros
.
data
group_num
=
packed_zeros
.
size
(
0
)
bits
=
self
.
quant_config
.
weight_bits
pack_factor
=
int
(
self
.
quant_config
.
pack_factor
)
input_size
,
packed_output_size
=
packed_weight
.
size
()
output_size
=
packed_output_size
*
pack_factor
isa_hint
=
_get_isa_hint
(
layer
.
scales
.
dtype
)
layer
.
isa_hint
=
isa_hint
interleave_map
=
(
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
)
weight
=
unpack_cols
(
packed_weight
,
bits
,
input_size
,
output_size
,
)
zeros
=
unpack_cols
(
packed_zeros
,
bits
,
group_num
,
output_size
,
)
weight
=
(
weight
.
view
(
input_size
,
-
1
,
pack_factor
)[:,
:,
interleave_map
]
.
reshape
(
input_size
,
output_size
)
.
contiguous
()
)
zeros
=
(
zeros
.
view
(
group_num
,
-
1
,
pack_factor
)[:,
:,
interleave_map
]
.
reshape
(
group_num
,
output_size
)
.
contiguous
()
)
zeros
=
pack_cols
(
zeros
,
bits
,
group_num
,
output_size
).
contiguous
()
# make 16 output channel as a block and transpose to
# the make the block contigous
weight
=
pack_cols
(
weight
,
bits
,
input_size
,
output_size
)
weight
=
(
weight
.
view
(
input_size
,
-
1
,
16
//
pack_factor
)
.
permute
(
1
,
0
,
2
)
.
reshape
(
-
1
,
input_size
*
16
//
pack_factor
)
.
contiguous
()
)
layer
.
qweight
.
data
=
weight
layer
.
qzeros
.
data
=
zeros
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
x
=
cpu_gemm_wna16
(
input
=
x
,
q_weight
=
layer
.
qweight
,
scales
=
layer
.
scales
,
zeros
=
layer
.
qzeros
,
g_idx
=
None
,
bias
=
bias
,
pack_factor
=
8
,
isa_hint
=
layer
.
isa_hint
,
)
return
x
def
_get_isa_hint
(
dtype
:
torch
.
dtype
)
->
str
:
supports_amx
=
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
if
supports_amx
and
dtype
in
(
torch
.
bfloat16
,):
return
"amx"
else
:
return
"vec"
vllm/model_executor/layers/quantization/ipex_quant.py
View file @
20852c8f
...
...
@@ -134,7 +134,7 @@ class IPEXConfig(QuantizationConfig):
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
QuantizationMethods
|
None
:
if
not
current_platform
.
is_cpu
()
and
not
current_platform
.
is_xpu
():
if
not
current_platform
.
is_xpu
():
return
None
quant_method
=
hf_quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment