Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47b7af0d
Unverified
Commit
47b7af0d
authored
Mar 19, 2026
by
Tianmu Li
Committed by
GitHub
Mar 20, 2026
Browse files
[Feat] Enable CompressedTensorW4A8Int for XPU (#37207)
Signed-off-by:
Li, Tianmu
<
tianmu.li@intel.com
>
parent
269bf46d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
172 additions
and
0 deletions
+172
-0
vllm/_xpu_ops.py
vllm/_xpu_ops.py
+54
-0
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+3
-0
vllm/model_executor/kernels/linear/mixed_precision/__init__.py
...model_executor/kernels/linear/mixed_precision/__init__.py
+2
-0
vllm/model_executor/kernels/linear/mixed_precision/xpu.py
vllm/model_executor/kernels/linear/mixed_precision/xpu.py
+113
-0
No files found.
vllm/_xpu_ops.py
View file @
47b7af0d
...
@@ -37,6 +37,26 @@ if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
...
@@ -37,6 +37,26 @@ if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):
return
torch
.
empty
((
M
,
N
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
return
torch
.
empty
((
M
,
N
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a8"):

    @register_fake("_xpu_C::int4_gemm_w4a8")
    def _int4_gemm_w4a8_fake(
        input: torch.Tensor,
        input_scales: torch.Tensor,
        input_zero_points: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        group_size: int,
        g_idx: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake (shape-only) implementation of ``_xpu_C::int4_gemm_w4a8``.

        Only ``input`` and ``q_weight`` take part in shape inference; the
        remaining arguments exist to mirror the op signature.

        Returns:
            An uninitialized float16 tensor of shape ``(M, N)`` on
            ``input.device``, where ``M`` is the number of rows of ``input``
            once flattened to 2-D and ``N`` is ``q_weight.size(1)``.
        """
        # Collapse every leading dimension so M counts all rows of the input.
        flattened = input.view(-1, input.shape[-1])
        out_shape = (flattened.size(0), q_weight.size(1))
        # The output dtype is fixed to float16, independent of input.dtype.
        return torch.empty(out_shape, dtype=torch.float16, device=input.device)
if
hasattr
(
torch
.
ops
.
_xpu_C
,
"int4_gemm_w4a16"
):
if
hasattr
(
torch
.
ops
.
_xpu_C
,
"int4_gemm_w4a16"
):
@
register_fake
(
"_xpu_C::int4_gemm_w4a16"
)
@
register_fake
(
"_xpu_C::int4_gemm_w4a16"
)
...
@@ -87,6 +107,40 @@ _OPS_REGISTERED = False
...
@@ -87,6 +107,40 @@ _OPS_REGISTERED = False
class
xpu_ops
:
class
xpu_ops
:
@
staticmethod
@
torch
.
compile
def
dynamic_per_token_int8_quant_ref
(
input
:
torch
.
Tensor
,
use_sym_quant
:
bool
,
bits
:
int
):
original_sizes
=
input
.
size
()
# view is not safe in torch.compile if input is not contiguous
input
=
input
.
reshape
(
-
1
,
original_sizes
[
-
1
]
)
# Flatten except for the last dimension
qmin
=
-
(
2
**
(
bits
-
1
))
if
use_sym_quant
else
0
qmax
=
2
**
(
bits
-
1
)
-
1
if
use_sym_quant
else
2
**
bits
-
1
min_val
=
torch
.
min
(
input
,
dim
=-
1
)[
0
].
to
(
dtype
=
torch
.
float32
).
unsqueeze
(
-
1
)
max_val
=
torch
.
max
(
input
,
dim
=-
1
)[
0
].
to
(
dtype
=
torch
.
float32
).
unsqueeze
(
-
1
)
if
use_sym_quant
:
scale
=
(
torch
.
maximum
(
torch
.
abs
(
min_val
),
torch
.
abs
(
max_val
))
/
qmax
).
clamp
(
min
=
1e-5
)
zero_point
=
torch
.
zeros_like
(
scale
).
to
(
dtype
=
torch
.
int32
)
else
:
scale
=
((
max_val
-
min_val
)
/
qmax
).
clamp
(
min
=
1e-5
)
zero_point
=
-
1
*
torch
.
round
(
min_val
/
scale
).
to
(
dtype
=
torch
.
int32
)
scale
=
scale
.
to
(
dtype
=
input
.
dtype
)
quantized
=
torch
.
clamp
(
torch
.
round
(
input
/
scale
.
to
(
dtype
=
torch
.
float32
)
+
zero_point
),
qmin
,
qmax
,
).
to
(
dtype
=
torch
.
int8
if
use_sym_quant
else
torch
.
uint8
)
return
(
quantized
.
view
(
original_sizes
),
scale
.
view
(
original_sizes
[:
-
1
]
+
(
1
,)),
zero_point
.
view
(
original_sizes
[:
-
1
]
+
(
1
,)),
)
@
staticmethod
@
staticmethod
def
flash_attn_varlen_func
(
def
flash_attn_varlen_func
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
...
...
vllm/model_executor/kernels/linear/__init__.py
View file @
47b7af0d
...
@@ -48,6 +48,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
...
@@ -48,6 +48,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel
,
MarlinLinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.mixed_precision.xpu
import
(
from
vllm.model_executor.kernels.linear.mixed_precision.xpu
import
(
XPUW4A8IntLinearKernel
,
XPUwNa16LinearKernel
,
XPUwNa16LinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.scaled_mm
import
(
from
vllm.model_executor.kernels.linear.scaled_mm
import
(
...
@@ -138,6 +139,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
...
@@ -138,6 +139,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
ExllamaLinearKernel
,
ExllamaLinearKernel
,
],
],
PlatformEnum
.
XPU
:
[
PlatformEnum
.
XPU
:
[
XPUW4A8IntLinearKernel
,
XPUwNa16LinearKernel
,
XPUwNa16LinearKernel
,
],
],
PlatformEnum
.
CPU
:
[
PlatformEnum
.
CPU
:
[
...
@@ -391,5 +393,6 @@ __all__ = [
...
@@ -391,5 +393,6 @@ __all__ = [
"ExllamaLinearKernel"
,
"ExllamaLinearKernel"
,
"MacheteLinearKernel"
,
"MacheteLinearKernel"
,
"MarlinLinearKernel"
,
"MarlinLinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUwNa16LinearKernel"
,
"XPUwNa16LinearKernel"
,
]
]
vllm/model_executor/kernels/linear/mixed_precision/__init__.py
View file @
47b7af0d
...
@@ -30,6 +30,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
...
@@ -30,6 +30,7 @@ from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
MPLinearLayerConfig
,
MPLinearLayerConfig
,
)
)
from
vllm.model_executor.kernels.linear.mixed_precision.xpu
import
(
from
vllm.model_executor.kernels.linear.mixed_precision.xpu
import
(
XPUW4A8IntLinearKernel
,
XPUwNa16LinearKernel
,
XPUwNa16LinearKernel
,
)
)
...
@@ -44,5 +45,6 @@ __all__ = [
...
@@ -44,5 +45,6 @@ __all__ = [
"ExllamaLinearKernel"
,
"ExllamaLinearKernel"
,
"MacheteLinearKernel"
,
"MacheteLinearKernel"
,
"MarlinLinearKernel"
,
"MarlinLinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUwNa16LinearKernel"
,
"XPUwNa16LinearKernel"
,
]
]
vllm/model_executor/kernels/linear/mixed_precision/xpu.py
View file @
47b7af0d
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
...
@@ -12,6 +14,8 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
...
@@ -12,6 +14,8 @@ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_XPUWNA16_SUPPORTED_QUANT_TYPES
=
(
scalar_types
.
uint4
,
scalar_types
.
uint4b8
)
_XPUWNA16_SUPPORTED_QUANT_TYPES
=
(
scalar_types
.
uint4
,
scalar_types
.
uint4b8
)
logger
=
init_logger
(
__name__
)
class
XPUwNa16LinearKernel
(
MPLinearKernel
):
class
XPUwNa16LinearKernel
(
MPLinearKernel
):
@
classmethod
@
classmethod
...
@@ -86,3 +90,112 @@ class XPUwNa16LinearKernel(MPLinearKernel):
...
@@ -86,3 +90,112 @@ class XPUwNa16LinearKernel(MPLinearKernel):
layer
.
g_idx
,
layer
.
g_idx
,
)
)
return
out
return
out
class XPUW4A8IntLinearKernel(MPLinearKernel):
    """XPU kernel for W4A8 integer quantization using oneDNN int4_gemm_w4a8.

    Weights are symmetric group-quantized int4 packed as uint4.
    Activations are dynamically quantized per-token to symmetric int8.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability for this kernel (always ``-1``).

        NOTE(review): -1 presumably means "no compute-capability
        requirement" -- confirm against the MPLinearKernel contract.
        """
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Check whether this kernel supports the layer configuration ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        if not current_platform.is_xpu():
            return False, "XPUW4A8Int only supported on XPU"
        if c.act_type not in (torch.bfloat16, torch.float16):
            return False, "XPUW4A8Int requires BF16/FP16 activations"
        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"XPUW4A8Int requires int4 weights, got {c.weight_type}",
            )
        if c.zero_points:
            return False, "XPUW4A8Int only supports symmetric weight quantization"
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by XPUW4A8Int, "
                "must be a multiple of 32",
            )
        in_size, out_size = c.partition_weight_shape
        # _pack_int4_weight packs 8 nibbles per int32 along the input dim.
        if in_size % 8 != 0 or out_size % 8 != 0:
            return (
                False,
                f"in/out sizes ({in_size}, {out_size}) must be multiples of 8",
            )

        if c.act_type != torch.float16:
            # apply_weights casts the kernel output back to the activation
            # dtype, which is extra work for non-float16 models.
            logger.warning_once(
                "XPUW4A8IntLinearKernel is running with model dtype %s, "
                "but int4_gemm_w4a8 produces float16 output. Recommend "
                "setting --dtype float16 for best performance.",
                c.act_type,
            )
        return True, None

    def _pack_int4_weight(self, w: torch.Tensor) -> torch.Tensor:
        """Pack signed int4 values into int32 words, 8 values per word.

        Args:
            w: ``[N, K]`` int8 tensor with values in ``[-8, 7]``; ``K`` must
                be a multiple of 8.

        Returns:
            ``[N, K/8]`` int32 tensor. Within each word, element ``i`` of a
            group of 8 occupies bits ``4*i .. 4*i+3`` after being offset by
            +8 into the unsigned range ``[0, 15]``.
        """
        # w is [N, K] int8 with values in [-8, 7]
        w_u4 = w.to(torch.int32) + 8  # shift to [0, 15]
        w_u4 = w_u4.reshape(w.shape[0], w.shape[1] // 8, 8)  # [N, K/8, 8]
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
        # The nibbles occupy disjoint bit ranges, so summing them assembles
        # the packed word.
        packed = ((w_u4 & 0xF) << shifts[None, None, :]).sum(dim=2).to(torch.int32)
        return packed

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters of ``layer`` into kernel layout.

        Transposes ``weight_scale``, installs a constant zero point of 8
        (matching the +8 offset applied when packing), replaces the
        quantized weight with its packed int32 form, and clears the
        ``weight`` parameter slot.
        """
        layer.weight_scale.data = layer.weight_scale.data.t().contiguous()

        device = layer.weight_packed.device
        # TODO: support asymmetric quantization
        weight_zero_point = torch.tensor([8], dtype=torch.int8, device=device)
        layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)

        # weight_packed is [out, in] int8, signed int4 values in [-8, 7]
        w = layer.weight_packed.data  # [out, in]

        # TODO: implement asym case
        packed = self._pack_int4_weight(w)  # [out, in/8] packed uint4

        replace_parameter(
            layer,
            self.w_q_name,
            Parameter(packed, requires_grad=False),
        )

        # Free the original unpacked int8 weight (still registered as "weight")
        # to avoid double-storing both int8 [N, K] and int32 [N, K/8] in memory.
        layer.register_parameter("weight", None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the W4A8 linear projection of ``x``.

        Activations are quantized per token to symmetric int8 and multiplied
        with the packed int4 weights via ``_xpu_C.int4_gemm_w4a8``.

        Args:
            layer: Module holding ``weight_packed``, ``weight_scale`` and
                ``weight_zero_point`` as prepared by
                ``process_weights_after_loading``.
            x: Activations of shape ``[..., K]``.
            bias: Optional bias forwarded to the kernel.

        Returns:
            Tensor of shape ``[..., N]`` in ``x.dtype``.
        """
        reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]
        # Imported lazily, as in the original implementation.
        from vllm._xpu_ops import xpu_ops as ops

        # TODO: static and asymmetric quantization case
        # Common code for CompressedTensorsW4A8Int does not read act symmetry data
        quant_x, x_scale, x_zero = ops.dynamic_per_token_int8_quant_ref(
            reshaped_x, True, 8
        )
        out = torch.ops._xpu_C.int4_gemm_w4a8(
            quant_x,
            x_scale,
            x_zero,
            layer.weight_packed.t(),
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            None,  # g_idx not currently supported
            bias,
        )
        # The GEMM runs on the flattened [M, K] view; restore the caller's
        # leading dimensions so an [..., K] input yields an [..., N] output.
        # This is a no-op for 2-D inputs.
        out = out.reshape(x.shape[:-1] + (out.shape[-1],))
        return out.to(x.dtype)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment