Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f09daea2
Unverified
Commit
f09daea2
authored
Mar 31, 2026
by
Yintong Lu
Committed by
GitHub
Mar 31, 2026
Browse files
[CPU] Support int8 compute mode in CPU AWQ (#35697)
Signed-off-by:
Yintong Lu
<
yintong.lu@intel.com
>
parent
42318c84
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1197 additions
and
11 deletions
+1197
-11
.buildkite/hardware_tests/cpu.yaml
.buildkite/hardware_tests/cpu.yaml
+3
-1
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+1
-0
csrc/cpu/sgl-kernels/common.h
csrc/cpu/sgl-kernels/common.h
+8
-0
csrc/cpu/sgl-kernels/gemm.h
csrc/cpu/sgl-kernels/gemm.h
+36
-3
csrc/cpu/sgl-kernels/gemm_int4.cpp
csrc/cpu/sgl-kernels/gemm_int4.cpp
+755
-0
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+20
-0
tests/kernels/test_awq_int4_to_int8.py
tests/kernels/test_awq_int4_to_int8.py
+281
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+32
-0
vllm/envs.py
vllm/envs.py
+3
-0
vllm/model_executor/layers/quantization/cpu_wna16.py
vllm/model_executor/layers/quantization/cpu_wna16.py
+58
-7
No files found.
.buildkite/hardware_tests/cpu.yaml
View file @
f09daea2
...
@@ -13,12 +13,14 @@ steps:
...
@@ -13,12 +13,14 @@ steps:
-
tests/kernels/attention/test_cpu_attn.py
-
tests/kernels/attention/test_cpu_attn.py
-
tests/kernels/moe/test_cpu_fused_moe.py
-
tests/kernels/moe/test_cpu_fused_moe.py
-
tests/kernels/test_onednn.py
-
tests/kernels/test_onednn.py
-
tests/kernels/test_awq_int4_to_int8.py
commands
:
commands
:
-
|
-
|
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py"
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
-
label
:
CPU-Compatibility Tests
-
label
:
CPU-Compatibility Tests
depends_on
:
[]
depends_on
:
[]
...
...
cmake/cpu_extension.cmake
View file @
f09daea2
...
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
...
@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
)
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
)
...
...
csrc/cpu/sgl-kernels/common.h
View file @
f09daea2
...
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
...
@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
#endif
#endif
}
}
// Index of the calling thread as reported by OpenMP. Builds compiled without
// OpenMP have no thread team to query and always report 0.
inline int get_thread_num() {
#if !defined(_OPENMP)
  return 0;
#else
  return omp_get_thread_num();
#endif
}
// for 1d parallel, use `actual_nth`
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
// for 2d parallel, use even nths, e.g. 43->42
int
inline
adjust_num_threads
(
int
m
)
{
int
inline
adjust_num_threads
(
int
m
)
{
...
...
csrc/cpu/sgl-kernels/gemm.h
View file @
f09daea2
...
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
...
@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template
<
typename
T
>
inline
bool
can_use_brgemm
(
int
M
);
template
<
typename
T
>
inline
bool
can_use_brgemm
(
int
M
);
template
<
>
inline
bool
can_use_brgemm
<
at
::
BFloat16
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
BFloat16
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Half
>
(
int
M
)
{
return
true
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Half
>
(
int
M
)
{
return
true
;
}
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template
<
>
inline
bool
can_use_brgemm
<
int8_t
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
int8_t
>
(
int
M
)
{
return
false
;
}
template
<
>
inline
bool
can_use_brgemm
<
u
int8_t
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Float8_e4m3fn
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
Float8_e4m3fn
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
quint4x2
>
(
int
M
)
{
return
M
>
4
;
}
template
<
>
inline
bool
can_use_brgemm
<
at
::
quint4x2
>
(
int
M
)
{
return
M
>
4
;
}
...
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
...
@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return
use_int8_w8a8
?
K
+
sizeof
(
int32_t
)
:
K
;
return
use_int8_w8a8
?
K
+
sizeof
(
int32_t
)
:
K
;
}
}
// pack weight to vnni format
// Returns `group_size` capped at 128: values up to and including 128 are
// passed through unchanged, anything larger yields 128.
inline int64_t get_4bit_block_k_size(int64_t group_size) {
  constexpr int64_t kMaxBlockK = 128;
  if (group_size <= kMaxBlockK) {
    return group_size;
  }
  return kMaxBlockK;
}
// pack weight into vnni format
at
::
Tensor
convert_weight_packed
(
at
::
Tensor
&
weight
);
at
::
Tensor
convert_weight_packed
(
at
::
Tensor
&
weight
);
// pack weight to vnni format for int4 (adapted from sglang)
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_weight_packed_scale_zp
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
,
at
::
Tensor
scales
);
// moe implementations for int8 w8a8
// moe implementations for int8 w8a8
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
fused_experts_int8_kernel_impl
(
void
fused_experts_int8_kernel_impl
(
...
@@ -233,6 +241,31 @@ void tinygemm_kernel(
...
@@ -233,6 +241,31 @@ void tinygemm_kernel(
int64_t
strideBs
,
int64_t
strideBs
,
bool
brg
);
bool
brg
);
// int4 scaled GEMM (adapted from sglang)
at
::
Tensor
int4_scaled_mm_cpu
(
at
::
Tensor
&
x
,
at
::
Tensor
&
w
,
at
::
Tensor
&
w_zeros
,
at
::
Tensor
&
w_scales
,
std
::
optional
<
at
::
Tensor
>
bias
);
// int4 tinygemm kernel interface(adapted from sglang)
template
<
typename
scalar_t
>
void
tinygemm_kernel
(
scalar_t
*
C
,
float
*
C_temp
,
const
uint8_t
*
A
,
const
float
*
scales_a
,
const
int32_t
*
qzeros_a
,
const
uint8_t
*
B
,
const
float
*
scales_b
,
const
int8_t
*
qzeros_b
,
const
int32_t
*
compensation
,
int8_t
*
dqB_tmp
,
int64_t
M
,
int64_t
K
,
int64_t
lda
,
int64_t
ldc_f
,
int64_t
ldc_s
,
bool
store_out
,
bool
use_brgemm
);
// TODO: debug print, remove me later
// TODO: debug print, remove me later
inline
void
print_16x32i
(
const
__m512i
x
)
{
inline
void
print_16x32i
(
const
__m512i
x
)
{
int32_t
a
[
16
];
int32_t
a
[
16
];
...
...
csrc/cpu/sgl-kernels/gemm_int4.cpp
0 → 100644
View file @
f09daea2
This diff is collapsed.
Click to expand it.
csrc/cpu/torch_bindings.cpp
View file @
f09daea2
...
@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
...
@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
const
std
::
optional
<
at
::
Tensor
>&
bias
,
at
::
ScalarType
out_dtype
,
bool
is_vnni
);
at
::
ScalarType
out_dtype
,
bool
is_vnni
);
// Adapted from sglang: INT4 W4A8 kernels
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
convert_weight_packed_scale_zp
(
at
::
Tensor
qweight
,
at
::
Tensor
qzeros
,
at
::
Tensor
scales
);
at
::
Tensor
int4_scaled_mm_cpu
(
at
::
Tensor
&
x
,
at
::
Tensor
&
w
,
at
::
Tensor
&
w_zeros
,
at
::
Tensor
&
w_scales
,
std
::
optional
<
at
::
Tensor
>
bias
);
torch
::
Tensor
get_scheduler_metadata
(
torch
::
Tensor
get_scheduler_metadata
(
const
int64_t
num_req
,
const
int64_t
num_heads_q
,
const
int64_t
num_req
,
const
int64_t
num_heads_q
,
const
int64_t
num_heads_kv
,
const
int64_t
head_dim
,
const
int64_t
num_heads_kv
,
const
int64_t
head_dim
,
...
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"
);
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"
);
ops
.
impl
(
"int8_scaled_mm_with_quant"
,
torch
::
kCPU
,
ops
.
impl
(
"int8_scaled_mm_with_quant"
,
torch
::
kCPU
,
&
int8_scaled_mm_with_quant
);
&
int8_scaled_mm_with_quant
);
// Adapted from sglang: INT4 W4A8 kernels
ops
.
def
(
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
"Tensor scales) -> (Tensor, Tensor, Tensor)"
);
ops
.
impl
(
"convert_weight_packed_scale_zp"
,
torch
::
kCPU
,
&
convert_weight_packed_scale_zp
);
ops
.
def
(
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor"
);
ops
.
impl
(
"int4_scaled_mm_cpu"
,
torch
::
kCPU
,
&
int4_scaled_mm_cpu
);
#endif
#endif
// CPU attention kernels
// CPU attention kernels
...
...
tests/kernels/test_awq_int4_to_int8.py
0 → 100644
View file @
f09daea2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AWQ INT4 W4A8 GEMM pipeline (SGLang kernel migration).
Part 1: Weight packing tests
- convert_weight_packed_scale_zp correctness
Part 2: INT4 W4A8 GEMM tests
- int4_scaled_mm_cpu correctness w.r.t. float reference
- Bias, 3D input, various shapes
Part 3: create_weights shapes
cmd:
VLLM_CPU_INT4_W4A8=1 python -m pytest tests/kernels/test_awq_int4_to_int8.py -v -s
"""
import
numpy
as
np
import
pytest
import
torch
from
vllm._custom_ops
import
_supports_cpu_w4a8_int8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_cols
,
)
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_cpu
():
pytest
.
skip
(
"skipping CPU-only tests"
,
allow_module_level
=
True
)
requires_cpu_w4a8_int8
=
pytest
.
mark
.
skipif
(
not
_supports_cpu_w4a8_int8
,
reason
=
"Requires vLLM CPU build with SGLang INT4 W4A8 kernels"
,
)
def make_awq_checkpoint_data(K, N, group_size, seed=42):
    """Generate a synthetic AWQ checkpoint together with its float reference.

    Random 4-bit weights and zero points are drawn from a NumPy RandomState
    seeded with ``seed``, so the same arguments always give the same data.

    Returns a 6-tuple:
        packed_qweight:   [K, N//8] int32 (AWQ interleaved + packed)
        packed_qzeros:    [num_groups, N//8] int32 (AWQ interleaved + packed)
        scales:           [num_groups, N] float32
        float_ref:        [K, N] float32, reference dequantized weights
        weight_int4_orig: [K, N] int32, original int4 values (0-15)
        zeros_int4_orig:  [num_groups, N] int32, original zero points (0-15)
    """
    generator = np.random.RandomState(seed)
    num_groups = K // group_size

    def _random_int4(rows):
        # Uniform integers in [0, 16), stored as int32.
        drawn = generator.randint(0, 16, size=(rows, N))
        return torch.from_numpy(drawn.astype(np.int32))

    # Draw order matters for reproducibility: weights, zero points, scales.
    weight_int4_orig = _random_int4(K)
    zeros_int4_orig = _random_int4(num_groups)
    scales = torch.from_numpy(
        (generator.randn(num_groups, N) * 0.05).astype(np.float32)
    )

    # Reference dequantization: (q - zero_point) * scale, with each group's
    # zero point and scale repeated over the group's `group_size` rows.
    per_row_zeros = zeros_int4_orig.repeat_interleave(group_size, dim=0)
    per_row_scales = scales.repeat_interleave(group_size, dim=0)
    float_ref = (weight_int4_orig.float() - per_row_zeros.float()) * per_row_scales

    awq_interleave = [0, 2, 4, 6, 1, 3, 5, 7]

    def _interleave_and_pack(values, rows):
        # Reorder every run of 8 columns with the AWQ permutation, then pack
        # the 4-bit values column-wise.
        shuffled = (
            values.reshape(-1, 8)[:, awq_interleave].reshape(rows, N).contiguous()
        )
        return pack_cols(shuffled, 4, rows, N)

    packed_qweight = _interleave_and_pack(weight_int4_orig, K)
    packed_qzeros = _interleave_and_pack(zeros_int4_orig, num_groups)

    return (
        packed_qweight,
        packed_qzeros,
        scales,
        float_ref,
        weight_int4_orig,
        zeros_int4_orig,
    )
class TestConvertWeightPackedScaleZp:
    """Tests for convert_weight_packed_scale_zp weight packing."""

    @requires_cpu_w4a8_int8
    @pytest.mark.parametrize(
        "K,N,group_size",
        [
            (128, 128, 128),
            (256, 256, 128),
            (512, 256, 64),
        ],
    )
    def test_packing_output_shapes(self, K, N, group_size):
        """The packed weight, zero-point and scale tensors have the expected
        leading dimensions."""
        checkpoint = make_awq_checkpoint_data(K, N, group_size)
        packed_qweight, packed_qzeros, scales = checkpoint[:3]

        packed = torch.ops._C.convert_weight_packed_scale_zp(
            packed_qweight, packed_qzeros, scales
        )
        blocked_w, blocked_zp, blocked_s = packed

        # The test expects one scale/zero-point block per 32 output columns.
        block_n = 32
        Nc = N // block_n

        assert blocked_w.dim() >= 2, (
            f"blocked_w should have >= 2 dims, got {blocked_w.dim()}"
        )
        assert blocked_s.size(0) == Nc, (
            f"Expected Nc={Nc} scale blocks, got {blocked_s.size(0)}"
        )
        assert blocked_zp.size(0) == Nc, (
            f"Expected Nc={Nc} qzeros blocks, got {blocked_zp.size(0)}"
        )
        print(
            f" [PASS] packing shapes K={K}, N={N}, gs={group_size}: "
            f"blocked_w={list(blocked_w.shape)}, "
            f"blocked_s={list(blocked_s.shape)}, blocked_zp={list(blocked_zp.shape)}"
        )
class TestInt4ScaledMmCpu:
    """Tests for int4_scaled_mm_cpu GEMM kernel.

    Every test builds synthetic AWQ weights, packs them with
    convert_weight_packed_scale_zp, runs the kernel and compares against a
    float32 matmul with the dequantized reference weights. Activations and
    bias are drawn from a locally seeded generator so a failure can be
    reproduced exactly.
    """

    @staticmethod
    def _rng(seed=0):
        """Private, seeded generator; leaves the global torch RNG untouched."""
        return torch.Generator().manual_seed(seed)

    @staticmethod
    def _packed_weights(K, N, group_size):
        """Return (blocked_w, blocked_zp, blocked_s, float_ref) for one layer."""
        packed_qweight, packed_qzeros, scales, float_ref, _, _ = (
            make_awq_checkpoint_data(K, N, group_size)
        )
        blocked_w, blocked_zp, blocked_s = (
            torch.ops._C.convert_weight_packed_scale_zp(
                packed_qweight, packed_qzeros, scales
            )
        )
        return blocked_w, blocked_zp, blocked_s, float_ref

    @staticmethod
    def _error_stats(out, ref_out):
        """Return (abs_diff, ref_mag, mean_rel) of ``out`` against ``ref_out``.

        ``ref_mag`` is the mean absolute reference value plus 1e-6 (keeps the
        ratio finite for an all-zero reference); ``mean_rel`` is the mean
        absolute error divided by ``ref_mag``.
        """
        abs_diff = (out.float() - ref_out).abs()
        ref_mag = ref_out.abs().mean().item() + 1e-6
        mean_rel = abs_diff.mean().item() / ref_mag
        return abs_diff, ref_mag, mean_rel

    @requires_cpu_w4a8_int8
    @pytest.mark.parametrize(
        "M,K,N,group_size",
        [
            (1, 128, 128, 128),
            (4, 256, 256, 128),
            (16, 512, 256, 64),
            (32, 256, 512, 128),
            (64, 512, 512, 128),
        ],
    )
    def test_gemm_vs_float_reference(self, M, K, N, group_size):
        """INT4 W4A8 GEMM should approximate float matmul."""
        blocked_w, blocked_zp, blocked_s, float_ref = self._packed_weights(
            K, N, group_size
        )

        x = torch.randn(M, K, generator=self._rng(), dtype=torch.bfloat16)
        out = torch.ops._C.int4_scaled_mm_cpu(
            x, blocked_w, blocked_zp, blocked_s, None
        )
        ref_out = torch.mm(x.float(), float_ref)

        abs_diff, ref_mag, mean_rel = self._error_stats(out, ref_out)
        pct95 = torch.quantile(abs_diff, 0.95).item()

        assert mean_rel < 0.05, (
            f"Mean relative error {mean_rel:.4f} exceeds 5% threshold"
        )
        assert pct95 < ref_mag * 0.15, (
            f"95th-pctile abs_diff {pct95:.4f} exceeds 15% of ref magnitude"
        )
        print(f" [PASS] INT4 GEMM correct: M={M}, K={K}, N={N}")

    @requires_cpu_w4a8_int8
    @pytest.mark.parametrize("M", [1, 8, 32])
    def test_gemm_with_bias(self, M):
        """INT4 W4A8 GEMM with bias should match reference."""
        K, N, group_size = 256, 128, 128
        blocked_w, blocked_zp, blocked_s, float_ref = self._packed_weights(
            K, N, group_size
        )

        rng = self._rng()
        bias = torch.randn(N, generator=rng, dtype=torch.float32)
        x = torch.randn(M, K, generator=rng, dtype=torch.bfloat16)

        out = torch.ops._C.int4_scaled_mm_cpu(
            x, blocked_w, blocked_zp, blocked_s, bias
        )
        ref_out = torch.mm(x.float(), float_ref) + bias

        _, _, mean_rel = self._error_stats(out, ref_out)
        assert mean_rel < 0.05, (
            f"Mean relative error {mean_rel:.4f} with bias exceeds 5%"
        )
        print(f" [PASS] INT4 GEMM with bias: M={M}")

    @requires_cpu_w4a8_int8
    def test_gemm_3d_input(self):
        """apply() reshapes 3D input [B, S, K] -> [B*S, K] -> back to 3D."""
        K, N, group_size = 256, 128, 128
        blocked_w, blocked_zp, blocked_s, float_ref = self._packed_weights(
            K, N, group_size
        )

        B, S = 2, 8
        x_3d = torch.randn(B, S, K, generator=self._rng(), dtype=torch.bfloat16)
        x_2d = x_3d.reshape(-1, K)

        out_2d = torch.ops._C.int4_scaled_mm_cpu(
            x_2d, blocked_w, blocked_zp, blocked_s, None
        )
        out_3d = out_2d.reshape(B, S, N)
        ref_out = torch.mm(x_2d.float(), float_ref).reshape(B, S, N)

        assert out_3d.shape == (B, S, N)
        _, _, mean_rel = self._error_stats(out_3d, ref_out)
        assert mean_rel < 0.05, f"Mean relative error {mean_rel:.4f} for 3D exceeds 5%"
        print(f" [PASS] 3D input [{B}, {S}, {K}] -> output [{B}, {S}, {N}]")

    @requires_cpu_w4a8_int8
    def test_gemm_fp16_input(self):
        """INT4 GEMM should also work with fp16 input."""
        K, N, group_size, M = 256, 256, 128, 8
        blocked_w, blocked_zp, blocked_s, float_ref = self._packed_weights(
            K, N, group_size
        )

        x = torch.randn(M, K, generator=self._rng(), dtype=torch.float16)
        out = torch.ops._C.int4_scaled_mm_cpu(
            x, blocked_w, blocked_zp, blocked_s, None
        )
        ref_out = torch.mm(x.float(), float_ref)

        _, _, mean_rel = self._error_stats(out, ref_out)
        assert mean_rel < 0.05, (
            f"Mean relative error {mean_rel:.4f} for fp16 exceeds 5%"
        )
        print(f" [PASS] fp16 input M={M}, K={K}, N={N}")
class TestCreateWeightsUnchanged:
    """Create_weights should still produce correct int4 placeholder shapes."""

    @pytest.mark.parametrize(
        "K,N,group_size",
        [
            (128, 128, 128),
            (256, 256, 128),
            (512, 256, 64),
        ],
    )
    def test_int4_placeholder_shapes(self, K, N, group_size):
        """Verify qweight, qzeros, scales shapes."""
        # NOTE(review): the tensors below are allocated directly with
        # torch.empty, so this checks the expected shape arithmetic rather
        # than the output of a create_weights call -- confirm that is intended.
        pack_factor = 8
        num_groups = K // group_size
        packed_cols = N // pack_factor

        qweight = torch.empty(K, packed_cols, dtype=torch.int32)
        qzeros = torch.empty(num_groups, packed_cols, dtype=torch.int32)
        scales = torch.empty(num_groups, N, dtype=torch.bfloat16)

        assert qweight.shape == (K, packed_cols)
        assert qzeros.shape == (num_groups, packed_cols)
        assert scales.shape == (num_groups, N)
        print(f" [PASS] create_weights shapes: K={K}, N={N}, gs={group_size}")
vllm/_custom_ops.py
View file @
f09daea2
...
@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
...
@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
if hasattr(torch.ops._C, "convert_weight_packed_scale_zp"):

    @register_fake("_C::convert_weight_packed_scale_zp")
    def convert_weight_packed_scale_zp_fake(
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fake (meta) implementation of convert_weight_packed_scale_zp.

        Returns three uninitialized tensors, each with the shape and dtype of
        the corresponding input.

        NOTE(review): the placeholders mirror the *inputs*; the real kernel's
        outputs may be laid out differently (the packing test expects the
        scale and zero-point outputs to have a leading dimension of N // 32)
        -- confirm this is adequate wherever the fake is traced.
        """
        blocked_w = torch.empty_like(qweight)
        blocked_zp = torch.empty_like(qzeros)
        blocked_s = torch.empty_like(scales)
        return blocked_w, blocked_zp, blocked_s
if hasattr(torch.ops._C, "int4_scaled_mm_cpu"):

    @register_fake("_C::int4_scaled_mm_cpu")
    def int4_scaled_mm_cpu_fake(
        x: torch.Tensor,
        w: torch.Tensor,
        w_zeros: torch.Tensor,
        w_scales: torch.Tensor,
        bias: torch.Tensor | None,
    ) -> torch.Tensor:
        """Fake (meta) implementation of int4_scaled_mm_cpu.

        Produces an uninitialized [x.size(0), N] tensor with x's dtype and
        device, where N is the product of the first and last dimensions of
        ``w_scales``. ``w``, ``w_zeros`` and ``bias`` do not affect the
        result's shape.
        """
        num_rows = x.size(0)
        num_cols = w_scales.size(0) * w_scales.size(-1)
        return torch.empty((num_rows, num_cols), dtype=x.dtype, device=x.device)
_supports_cpu_w4a8_int8
=
bool
(
hasattr
(
torch
.
ops
.
_C
,
"convert_weight_packed_scale_zp"
))
class
CPUDNNLGEMMHandler
:
class
CPUDNNLGEMMHandler
:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
handler_tensor
:
torch
.
Tensor
|
None
=
None
self
.
handler_tensor
:
torch
.
Tensor
|
None
=
None
...
...
vllm/envs.py
View file @
f09daea2
...
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
...
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
VLLM_CPU_NUM_OF_RESERVED_CPU
:
int
|
None
=
None
VLLM_CPU_NUM_OF_RESERVED_CPU
:
int
|
None
=
None
VLLM_CPU_SGL_KERNEL
:
bool
=
False
VLLM_CPU_SGL_KERNEL
:
bool
=
False
VLLM_ZENTORCH_WEIGHT_PREPACK
:
bool
=
True
VLLM_ZENTORCH_WEIGHT_PREPACK
:
bool
=
True
VLLM_CPU_INT4_W4A8
:
bool
=
True
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
:
Literal
[
"auto"
,
"nccl"
,
"shm"
]
=
"auto"
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
:
Literal
[
"auto"
,
"nccl"
,
"shm"
]
=
"auto"
...
@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZENTORCH_WEIGHT_PREPACK"
:
lambda
:
bool
(
"VLLM_ZENTORCH_WEIGHT_PREPACK"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ZENTORCH_WEIGHT_PREPACK"
,
"1"
))
int
(
os
.
getenv
(
"VLLM_ZENTORCH_WEIGHT_PREPACK"
,
"1"
))
),
),
# (CPU backend only) whether to use SGLang INT4 W4A8 kernels for AWQ.
"VLLM_CPU_INT4_W4A8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CPU_INT4_W4A8"
,
"1"
))),
# If the env var is set, Ray Compiled Graph uses the specified
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
# different pipeline-parallel stages.
...
...
vllm/model_executor/layers/quantization/cpu_wna16.py
View file @
f09daea2
...
@@ -7,9 +7,8 @@ import torch
...
@@ -7,9 +7,8 @@ import torch
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm._custom_ops
import
(
import
vllm.envs
as
envs
cpu_gemm_wna16
,
from
vllm
import
_custom_ops
as
ops
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearBase
,
...
@@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase):
...
@@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Repack the loaded AWQ tensors for the GEMM path this layer will use.

    Records the chosen path on ``layer.use_w4a8`` (read again by ``apply``)
    and dispatches to the matching repacking routine.

    Args:
        layer: module whose loaded quantized tensors are repacked in place.
    """
    # The W4A8 path is taken only when the VLLM_CPU_INT4_W4A8 switch is on
    # and torch reports AMX tile support; otherwise use the WOQ repack.
    layer.use_w4a8 = envs.VLLM_CPU_INT4_W4A8 and torch.cpu._is_amx_tile_supported()
    if layer.use_w4a8:
        self._process_weights_sglang_int4(layer)
    else:
        self._process_weights_woq(layer)
def
_process_weights_woq
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Original WOQ int4 repack path."""
packed_weight
=
layer
.
qweight
.
data
packed_weight
=
layer
.
qweight
.
data
packed_zeros
=
layer
.
qzeros
.
data
packed_zeros
=
layer
.
qzeros
.
data
group_num
=
packed_zeros
.
size
(
0
)
group_num
=
packed_zeros
.
size
(
0
)
...
@@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase):
...
@@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase):
)
)
zeros
=
pack_cols
(
zeros
,
bits
,
group_num
,
output_size
).
contiguous
()
zeros
=
pack_cols
(
zeros
,
bits
,
group_num
,
output_size
).
contiguous
()
# make 16 output channel as a block and transpose to
# the make the block contiguous
weight
=
pack_cols
(
weight
,
bits
,
input_size
,
output_size
)
weight
=
pack_cols
(
weight
,
bits
,
input_size
,
output_size
)
weight
=
(
weight
=
(
weight
.
view
(
input_size
,
-
1
,
16
//
pack_factor
)
weight
.
view
(
input_size
,
-
1
,
16
//
pack_factor
)
...
@@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase):
...
@@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer
.
qweight
.
data
=
weight
layer
.
qweight
.
data
=
weight
layer
.
qzeros
.
data
=
zeros
layer
.
qzeros
.
data
=
zeros
def
_process_weights_sglang_int4
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""SGLang INT4 W4A8 path: pack int4 weights with VNNI reordering."""
packed_weight
=
layer
.
qweight
.
data
packed_zeros
=
layer
.
qzeros
.
data
scales
=
layer
.
scales
.
data
blocked_w
,
blocked_zp
,
blocked_s
=
torch
.
ops
.
_C
.
convert_weight_packed_scale_zp
(
packed_weight
,
packed_zeros
,
scales
)
layer
.
packed_weight
=
blocked_w
layer
.
packed_qzeros
=
blocked_zp
layer
.
packed_scales
=
blocked_s
layer
.
qweight
=
None
layer
.
qzeros
=
None
layer
.
scales
=
None
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x
=
cpu_gemm_wna16
(
if
layer
.
use_w4a8
:
return
self
.
_apply_sglang_int4
(
layer
,
x
,
bias
)
return
self
.
_apply_woq
(
layer
,
x
,
bias
)
def
_apply_woq
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Original WOQ int4 GEMM path."""
x
=
ops
.
cpu_gemm_wna16
(
input
=
x
,
input
=
x
,
q_weight
=
layer
.
qweight
,
q_weight
=
layer
.
qweight
,
scales
=
layer
.
scales
,
scales
=
layer
.
scales
,
...
@@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase):
...
@@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase):
)
)
return
x
return
x
def
_apply_sglang_int4
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""SGLang INT4 W4A8 GEMM path."""
x_shape
=
x
.
shape
x_2d
=
x
.
reshape
(
-
1
,
x_shape
[
-
1
])
if
len
(
x_shape
)
>
2
else
x
out
=
torch
.
ops
.
_C
.
int4_scaled_mm_cpu
(
x_2d
,
layer
.
packed_weight
,
layer
.
packed_qzeros
,
layer
.
packed_scales
,
bias
,
)
out
=
out
.
reshape
(
x_shape
[:
-
1
]
+
(
out
.
size
(
-
1
),))
if
len
(
x_shape
)
>
2
else
out
return
out
def
_get_isa_hint
(
dtype
:
torch
.
dtype
)
->
str
:
def
_get_isa_hint
(
dtype
:
torch
.
dtype
)
->
str
:
supports_amx
=
torch
.
cpu
.
_is_amx_tile_supported
()
supports_amx
=
torch
.
cpu
.
_is_amx_tile_supported
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment