Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d8a7d8a
Unverified
Commit
0d8a7d8a
authored
Dec 05, 2025
by
Yi Liu
Committed by
GitHub
Dec 05, 2025
Browse files
[Compressed Tensors] Add XPU `wNa16` support (#29484)
Signed-off-by:
yiliu30
<
yi4.liu@intel.com
>
parent
9843e332
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
0 deletions
+102
-0
.buildkite/scripts/hardware_ci/run-xpu-test.sh
.buildkite/scripts/hardware_ci/run-xpu-test.sh
+1
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
...r/layers/quantization/kernels/mixed_precision/__init__.py
+4
-0
vllm/model_executor/layers/quantization/kernels/mixed_precision/xpu.py
...ecutor/layers/quantization/kernels/mixed_precision/xpu.py
+97
-0
No files found.
.buildkite/scripts/hardware_ci/run-xpu-test.sh
View file @
0d8a7d8a
...
@@ -38,6 +38,7 @@ docker run \
...
@@ -38,6 +38,7 @@ docker run \
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -cc.cudagraph_mode=NONE
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
python3 examples/offline_inference/basic/generate.py --model Intel/Qwen2.5-0.5B-W4A16-G128-AutoRound-LLMC-TEST-ONLY --enforce-eager
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
VLLM_ATTENTION_BACKEND=TRITON_ATTN python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
cd tests
cd tests
pytest -v -s v1/core
pytest -v -s v1/core
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
View file @
0d8a7d8a
...
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
...
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
MPLinearKernel
,
MPLinearKernel
,
MPLinearLayerConfig
,
MPLinearLayerConfig
,
)
)
from
vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu
import
(
# noqa: E501
XPUwNa16LinearKernel
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# in priority/performance order (when available)
# in priority/performance order (when available)
...
@@ -42,6 +45,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
...
@@ -42,6 +45,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
BitBLASLinearKernel
,
BitBLASLinearKernel
,
ConchLinearKernel
,
ConchLinearKernel
,
ExllamaLinearKernel
,
ExllamaLinearKernel
,
XPUwNa16LinearKernel
,
]
]
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/xpu.py
0 → 100644
View file @
0d8a7d8a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.platforms
import
current_platform
from
.MPLinearKernel
import
MPLinearKernel
,
MPLinearLayerConfig
class XPUwNa16LinearKernel(MPLinearKernel):
    """Weight-only quantized (wNa16) linear kernel for XPU backed by IPEX.

    After the checkpoint weights are loaded, the packed weight and its scales
    are handed to ``IPEXWeightOnlyQuantizedLinear.from_weight`` and the
    resulting module is stored on the layer as ``layer.ipex_qlinear``. The
    forward pass simply delegates to that module.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (always ``0``)."""
        return 0

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer described by ``c``.

        Returns:
            ``(True, None)`` when usable, otherwise ``(False, reason)``.
        """
        # The guard only admits XPU, so the reason must not advertise CPU.
        if not current_platform.is_xpu():
            return False, "IPEX wNa16 only supported on XPU devices"

        # TODO: (yiliu30) relax these restrictions in later PRs
        if c.zero_points:
            return False, "Zero points not supported for Now"

        return True, None

    @staticmethod
    def _import_ipex():
        """Import ``intel_extension_for_pytorch`` and enforce a minimum version.

        Returns:
            The imported ``intel_extension_for_pytorch`` module.

        Raises:
            ImportError: If the package is missing or older than the minimum
                supported version.
        """
        from packaging import version

        min_ipex_version = "2.6.0"
        # The requirement is quoted so that a shell does not interpret ">="
        # as an output redirection when the hint is copy-pasted.
        install_hint = (
            "Please install "
            f"intel_extension_for_pytorch>={min_ipex_version} via "
            f'`pip install "intel_extension_for_pytorch>={min_ipex_version}"`'
            " to use the IPEX wNa16 linear kernel."
        )

        # Keep the try body limited to the import itself so that the version
        # error raised below is not swallowed and re-wrapped by this handler.
        try:
            import intel_extension_for_pytorch as ipex
        except ImportError as err:
            raise ImportError(install_hint) from err

        if version.parse(ipex.__version__) < version.parse(min_ipex_version):
            raise ImportError(
                f"intel_extension_for_pytorch version {ipex.__version__} is "
                f"wrong. {install_hint}"
            )
        return ipex

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Build the IPEX quantized linear module from the loaded weights.

        Reads ``layer.weight_packed``, ``layer.weight_scale`` and, depending
        on the kernel config, ``layer.weight_g_idx`` /
        ``layer.weight_zero_point``. Sets ``layer.ipex_qlinear`` and
        ``layer.ipex_output_size``.

        Raises:
            ImportError: If a suitable ``intel_extension_for_pytorch`` is not
                installed.
        """
        # The bias is folded into the IPEX module unless the layer defers
        # bias addition to its caller.
        bias = layer.bias if not layer.skip_bias_add else None

        ipex = self._import_ipex()

        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
        # with better performance.
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
        # The weight will be de-packed from INT4 to INT8.
        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
        # The float activation will be quantized dynamically to INT8 using
        # IPEX's PER_BATCH mode.
        # NOTE(review): confirm the exact granularity against the IPEX docs.
        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH

        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=weight_dtype,
            lowp_mode=lowp_mode,
            act_quant_mode=act_quant_mode,
            group_size=self.config.group_size,
            weight_qscheme=ipex.quantization.WoqWeightQScheme.SYMMETRIC,
        )

        qweight = layer.weight_packed
        g_idx = layer.weight_g_idx if self.config.has_g_idx else None
        scales = layer.weight_scale
        qzeros = None
        # can_implement currently rejects zero points; this branch is kept
        # for when that restriction is relaxed.
        if self.config.zero_points:
            qzeros = layer.weight_zero_point.contiguous()

        # Transposed, contiguous copies are passed to IPEX.
        # NOTE(review): presumably the layout IPEX expects; confirm.
        qweight = qweight.t().contiguous()
        scales = scales.t().contiguous()

        in_features, out_features = (
            self.config.partition_weight_shape[0],
            self.config.partition_weight_shape[1],
        )
        # Remembered so apply_weights can restore the leading dimensions.
        layer.ipex_output_size = out_features
        layer.ipex_qlinear = (
            ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight(
                qweight,
                scales,
                qzeros,
                in_features=in_features,
                out_features=out_features,
                qconfig=qconfig,
                g_idx=g_idx,
                bias=bias,
                group_size=self.config.group_size,
                quant_method=0,  # `0` stands for the IPEX GPTQ
            )
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the IPEX quantized linear on ``x``.

        ``x`` is flattened to 2-D over its last dimension, passed through
        ``layer.ipex_qlinear``, and reshaped back to the original leading
        dimensions with ``layer.ipex_output_size`` as the last dimension.

        The ``bias`` argument is intentionally unused here: any bias was
        already given to the IPEX module in
        ``process_weights_after_loading``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)
        return out.reshape(x.shape[:-1] + (layer.ipex_output_size,))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment