Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0dd25a44
Unverified
Commit
0dd25a44
authored
Apr 01, 2026
by
Yi Liu
Committed by
GitHub
Mar 31, 2026
Browse files
[Quantization][Autoround][XPU] Add `W4A16` Support (#37986)
Signed-off-by:
yiliu30
<
yi4.liu@intel.com
>
parent
3896e021
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
165 additions
and
11 deletions
+165
-11
.buildkite/scripts/hardware_ci/run-xpu-test.sh
.buildkite/scripts/hardware_ci/run-xpu-test.sh
+1
-0
vllm/model_executor/layers/quantization/inc.py
vllm/model_executor/layers/quantization/inc.py
+164
-11
No files found.
.buildkite/scripts/hardware_ci/run-xpu-test.sh
View file @
0dd25a44
...
@@ -42,6 +42,7 @@ docker run \
...
@@ -42,6 +42,7 @@ docker run \
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel
python3 examples/basic/offline_inference/generate.py --model OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc --block-size 64 --enforce-eager --max-model-len 8192
cd tests
cd tests
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
pytest -v -s v1/core --ignore=v1/core/test_reset_prefix_cache_e2e.py --ignore=v1/core/test_scheduler_e2e.py
pytest -v -s v1/engine
pytest -v -s v1/engine
...
...
vllm/model_executor/layers/quantization/inc.py
View file @
0dd25a44
...
@@ -6,14 +6,24 @@ from typing import TYPE_CHECKING, Any
...
@@ -6,14 +6,24 @@ from typing import TYPE_CHECKING, Any
import
regex
as
re
import
regex
as
re
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
UnquantizedLinearMethod
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
)
from
vllm.model_executor.layers.quantization
import
(
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
QuantizationConfig
,
QuantizationMethods
,
QuantizationMethods
,
)
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
,
RowvLLMParameter
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
...
@@ -402,16 +412,30 @@ class INCConfig(QuantizationConfig):
...
@@ -402,16 +412,30 @@ class INCConfig(QuantizationConfig):
return
None
return
None
def
apply_
ipex
_quant_layer
(
self
,
layer
,
prefix
:
str
):
def
apply_
xpu_w4a16
_quant_layer
(
self
,
layer
,
prefix
:
str
):
weight_bits
,
group_size
,
sym
=
self
.
get_layer_config
(
layer
,
prefix
)
weight_bits
,
group_size
,
sym
=
self
.
get_layer_config
(
layer
,
prefix
)
if
not
self
.
check_quantized
(
weight_bits
):
if
not
self
.
check_quantized
(
weight_bits
):
if
isinstance
(
layer
,
(
LinearBase
,
ParallelLMHead
)):
if
isinstance
(
layer
,
(
LinearBase
,
ParallelLMHead
)):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
else
:
else
:
return
None
return
None
if
weight_bits
!=
4
:
raise
NotImplementedError
(
f
"INC on XPU only supports 4-bit quantization, "
f
"got weight_bits=
{
weight_bits
}
."
)
if
not
sym
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"INC
quantization is not supported during xpu kernel migration
."
"INC
W4A16 on XPU only supports symmetric quantization for now
."
)
)
if
isinstance
(
layer
,
(
LinearBase
,
ParallelLMHead
)):
return
INCXPULinearMethod
(
weight_bits
=
weight_bits
,
group_size
=
group_size
,
sym
=
sym
,
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
):
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
):
if
prefix
and
self
.
extra_config
:
if
prefix
and
self
.
extra_config
:
...
@@ -420,12 +444,8 @@ class INCConfig(QuantizationConfig):
...
@@ -420,12 +444,8 @@ class INCConfig(QuantizationConfig):
layer_name
==
prefix
or
layer_name
==
f
"model.
{
prefix
}
"
layer_name
==
prefix
or
layer_name
==
f
"model.
{
prefix
}
"
)
and
self
.
extra_config
[
layer_name
].
get
(
"bits"
,
16
)
>=
16
:
)
and
self
.
extra_config
[
layer_name
].
get
(
"bits"
,
16
)
>=
16
:
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
if
(
if
current_platform
.
is_xpu
():
current_platform
.
is_cpu
()
return
self
.
apply_xpu_w4a16_quant_layer
(
layer
,
prefix
)
or
current_platform
.
is_xpu
()
or
self
.
backend
==
"ipex"
):
return
self
.
apply_ipex_quant_layer
(
layer
,
prefix
)
if
"gptq"
in
self
.
packing_format
or
"gptq"
in
self
.
backend
:
if
"gptq"
in
self
.
packing_format
or
"gptq"
in
self
.
backend
:
return
self
.
apply_gptq_quant_layer
(
layer
,
prefix
)
return
self
.
apply_gptq_quant_layer
(
layer
,
prefix
)
if
"awq"
in
self
.
packing_format
or
"awq"
in
self
.
backend
:
if
"awq"
in
self
.
packing_format
or
"awq"
in
self
.
backend
:
...
@@ -440,3 +460,136 @@ class INCConfig(QuantizationConfig):
...
@@ -440,3 +460,136 @@ class INCConfig(QuantizationConfig):
if
is_auto_round_format
:
if
is_auto_round_format
:
return
cls
.
get_name
()
return
cls
.
get_name
()
return
None
return
None
class
INCXPULinearMethod
(
LinearMethodBase
):
"""XPU linear method for INC w4a16 GPTQ quantization (symmetric only).
Repacks GPTQ weights from [in_packed, out] to oneDNN [out, in_packed]
layout and calls torch.ops._xpu_C.int4_gemm_w4a16.
GPTQ format: qweight [in_packed, out] with sequential nibble order.
Note: Asymmetric quantization (sym=false) is not for now.
FIXME(yiliu30): Refine the implementation to reuse XPUwNa16LinearKernel.
"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
sym
:
bool
):
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
sym
=
sym
self
.
pack_factor
=
32
//
weight_bits
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
output_size
# Unused.
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
scales_and_zp_size
=
input_size_per_partition
//
self
.
group_size
# GPTQ: qweight [in // pack_factor, out] packed along input dim
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
,
packed_factor
=
self
.
pack_factor
,
weight_loader
=
weight_loader
,
)
# scales: [num_groups, out] params_dtype
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
0
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
# qzeros: [num_groups, out // pack_factor] int32
qzeros
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
pack_factor
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
pack_factor
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
# GPTQ checkpoints may include g_idx for activation reordering.
# Register it so the weight loader doesn't error on unexpected keys.
g_idx
=
RowvLLMParameter
(
data
=
torch
.
tensor
(
[
i
//
self
.
group_size
for
i
in
range
(
input_size_per_partition
)],
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Repack GPTQ weights into kernel-ready NT layout."""
device
=
layer
.
qweight
.
data
.
device
# oneDNN int4 kernel requires strides[0]==1 ("NT format"), but GPTQ
# checkpoint is [K_packed, N] contiguous with strides (N, 1).
# Two transposes are needed — neither alone can achieve this:
# 1. .t().contiguous() → [N, K_packed] contiguous in memory
# 2. .t() → [K_packed, N] view with strides (1, K_packed)
# The result has the same logical shape but strides[0]==1 as required.
qweight_ct
=
layer
.
qweight
.
data
.
t
().
contiguous
()
layer
.
qweight
=
Parameter
(
qweight_ct
.
t
(),
requires_grad
=
False
)
# Scales: [num_groups, out] — no change needed
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Symmetric: GPTQ v1 stores qzeros=7, effective zp = 7+1 = 8
# Kernel expects int8 scalar = 8
layer
.
qzeros
=
Parameter
(
torch
.
tensor
([
8
],
dtype
=
torch
.
int8
,
device
=
device
),
requires_grad
=
False
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# qweight is already in NT layout [K_packed, N] (strides (1, K_packed))
# from process_weights_after_loading — pass directly to kernel.
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
qweight
.
shape
[
1
],)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out
=
torch
.
ops
.
_xpu_C
.
int4_gemm_w4a16
(
reshaped_x
,
layer
.
qweight
,
bias
,
layer
.
scales
,
layer
.
qzeros
,
self
.
group_size
,
None
,
# g_idx not needed: desc_act is always False for INC models
)
return
out
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment