Unverified Commit f09daea2 authored by Yintong Lu's avatar Yintong Lu Committed by GitHub
Browse files

[CPU] Support int8 compute mode in CPU AWQ (#35697)


Signed-off-by: default avatarYintong Lu <yintong.lu@intel.com>
parent 42318c84
...@@ -13,12 +13,14 @@ steps: ...@@ -13,12 +13,14 @@ steps:
- tests/kernels/attention/test_cpu_attn.py - tests/kernels/attention/test_cpu_attn.py
- tests/kernels/moe/test_cpu_fused_moe.py - tests/kernels/moe/test_cpu_fused_moe.py
- tests/kernels/test_onednn.py - tests/kernels/test_onednn.py
- tests/kernels/test_awq_int4_to_int8.py
commands: commands:
- | - |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m " bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py" pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py"
- label: CPU-Compatibility Tests - label: CPU-Compatibility Tests
depends_on: [] depends_on: []
......
...@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA) ...@@ -373,6 +373,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/gemm.cpp" "csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp" "csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp" "csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/gemm_int4.cpp"
"csrc/cpu/sgl-kernels/moe.cpp" "csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp" "csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp") "csrc/cpu/sgl-kernels/moe_fp8.cpp")
......
...@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) { ...@@ -117,6 +117,14 @@ inline void parallel_for(int n, const func_t& f) {
#endif #endif
} }
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// for 1d parallel, use `actual_nth` // for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42 // for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) { int inline adjust_num_threads(int m) {
......
...@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; } ...@@ -17,8 +17,8 @@ constexpr int block_size_n() { return 2 * TILE_N; }
template <typename T> inline bool can_use_brgemm(int M); template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; } template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; } template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
// TODO: add u8s8 brgemm, this requires PyTorch 2.7 template <> inline bool can_use_brgemm<int8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; } template <> inline bool can_use_brgemm<uint8_t>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; } template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; } template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
...@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { ...@@ -40,9 +40,17 @@ inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K; return use_int8_w8a8 ? K + sizeof(int32_t) : K;
} }
// pack weight to vnni format inline int64_t get_4bit_block_k_size(int64_t group_size) {
return group_size > 128 ? 128 : group_size;
}
// pack weight into vnni format
at::Tensor convert_weight_packed(at::Tensor& weight); at::Tensor convert_weight_packed(at::Tensor& weight);
// pack weight to vnni format for int4 (adapted from sglang)
std::tuple<at::Tensor, at::Tensor, at::Tensor>
convert_weight_packed_scale_zp(at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
// moe implementations for int8 w8a8 // moe implementations for int8 w8a8
template <typename scalar_t> template <typename scalar_t>
void fused_experts_int8_kernel_impl( void fused_experts_int8_kernel_impl(
...@@ -233,6 +241,31 @@ void tinygemm_kernel( ...@@ -233,6 +241,31 @@ void tinygemm_kernel(
int64_t strideBs, int64_t strideBs,
bool brg); bool brg);
// int4 scaled GEMM (adapted from sglang)
at::Tensor int4_scaled_mm_cpu(
at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros, at::Tensor& w_scales, std::optional<at::Tensor> bias);
// int4 tinygemm kernel interface(adapted from sglang)
template <typename scalar_t>
void tinygemm_kernel(
scalar_t* C,
float* C_temp,
const uint8_t* A,
const float* scales_a,
const int32_t* qzeros_a,
const uint8_t* B,
const float* scales_b,
const int8_t* qzeros_b,
const int32_t* compensation,
int8_t* dqB_tmp,
int64_t M,
int64_t K,
int64_t lda,
int64_t ldc_f,
int64_t ldc_s,
bool store_out,
bool use_brgemm);
// TODO: debug print, remove me later // TODO: debug print, remove me later
inline void print_16x32i(const __m512i x) { inline void print_16x32i(const __m512i x) {
int32_t a[16]; int32_t a[16];
......
This diff is collapsed.
...@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, ...@@ -79,6 +79,14 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
const std::optional<at::Tensor>& bias, const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype, bool is_vnni); at::ScalarType out_dtype, bool is_vnni);
// Adapted from sglang: INT4 W4A8 kernels
std::tuple<at::Tensor, at::Tensor, at::Tensor> convert_weight_packed_scale_zp(
at::Tensor qweight, at::Tensor qzeros, at::Tensor scales);
at::Tensor int4_scaled_mm_cpu(at::Tensor& x, at::Tensor& w, at::Tensor& w_zeros,
at::Tensor& w_scales,
std::optional<at::Tensor> bias);
torch::Tensor get_scheduler_metadata( torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q, const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim, const int64_t num_heads_kv, const int64_t head_dim,
...@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -285,6 +293,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
ops.impl("int8_scaled_mm_with_quant", torch::kCPU, ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
&int8_scaled_mm_with_quant); &int8_scaled_mm_with_quant);
// Adapted from sglang: INT4 W4A8 kernels
ops.def(
"convert_weight_packed_scale_zp(Tensor qweight, Tensor qzeros, "
"Tensor scales) -> (Tensor, Tensor, Tensor)");
ops.impl("convert_weight_packed_scale_zp", torch::kCPU,
&convert_weight_packed_scale_zp);
ops.def(
"int4_scaled_mm_cpu(Tensor(a0!) x, Tensor(a1!) w, Tensor(a2!) w_zeros, "
"Tensor(a3!) w_scales, Tensor? bias) -> Tensor");
ops.impl("int4_scaled_mm_cpu", torch::kCPU, &int4_scaled_mm_cpu);
#endif #endif
// CPU attention kernels // CPU attention kernels
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for AWQ INT4 W4A8 GEMM pipeline (SGLang kernel migration).
Part 1: Weight packing tests
- convert_weight_packed_scale_zp correctness
Part 2: INT4 W4A8 GEMM tests
- int4_scaled_mm_cpu correctness w.r.t. float reference
- Bias, 3D input, various shapes
Part 3: create_weights shapes
cmd:
VLLM_CPU_INT4_W4A8=1 python -m pytest tests/kernels/test_awq_int4_to_int8.py -v -s
"""
import numpy as np
import pytest
import torch
from vllm._custom_ops import _supports_cpu_w4a8_int8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_cols,
)
from vllm.platforms import current_platform
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
requires_cpu_w4a8_int8 = pytest.mark.skipif(
not _supports_cpu_w4a8_int8,
reason="Requires vLLM CPU build with SGLang INT4 W4A8 kernels",
)
def make_awq_checkpoint_data(K, N, group_size, seed=42):
"""Create synthetic AWQ checkpoint data in packed int32 format.
Returns:
packed_qweight: [K, N//8] int32 (AWQ interleaved + packed)
packed_qzeros: [num_groups, N//8] int32 (AWQ interleaved + packed)
scales: [num_groups, N] float32
float_ref: [K, N] float32, reference dequantized weights
weight_int4_orig: [K, N] int32, original int4 values (0-15)
zeros_int4_orig: [num_groups, N] int32, original zero points (0-15)
"""
rng = np.random.RandomState(seed)
num_groups = K // group_size
weight_int4_orig = torch.from_numpy(
rng.randint(0, 16, size=(K, N)).astype(np.int32)
)
zeros_int4_orig = torch.from_numpy(
rng.randint(0, 16, size=(num_groups, N)).astype(np.int32)
)
scales = torch.from_numpy((rng.randn(num_groups, N) * 0.05).astype(np.float32))
scales_exp = scales.repeat_interleave(group_size, dim=0)
zeros_exp = zeros_int4_orig.repeat_interleave(group_size, dim=0)
float_ref = (weight_int4_orig.float() - zeros_exp.float()) * scales_exp
awq_interleave = [0, 2, 4, 6, 1, 3, 5, 7]
weight_interleaved = (
weight_int4_orig.reshape(-1, 8)[:, awq_interleave].reshape(K, N).contiguous()
)
packed_qweight = pack_cols(weight_interleaved, 4, K, N)
zeros_interleaved = (
zeros_int4_orig.reshape(-1, 8)[:, awq_interleave]
.reshape(num_groups, N)
.contiguous()
)
packed_qzeros = pack_cols(zeros_interleaved, 4, num_groups, N)
return (
packed_qweight,
packed_qzeros,
scales,
float_ref,
weight_int4_orig,
zeros_int4_orig,
)
class TestConvertWeightPackedScaleZp:
"""Tests for convert_weight_packed_scale_zp weightpacking."""
@requires_cpu_w4a8_int8
@pytest.mark.parametrize(
"K,N,group_size",
[
(128, 128, 128),
(256, 256, 128),
(512, 256, 64),
],
)
def test_packing_output_shapes(self, K, N, group_size):
"""Packed outputs should have expected shapes."""
(packed_qweight, packed_qzeros, scales, _, _, _) = make_awq_checkpoint_data(
K, N, group_size
)
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_qweight, packed_qzeros, scales
)
block_n = 32
Nc = N // block_n
assert blocked_w.dim() >= 2, (
f"blocked_w should have >= 2 dims, got {blocked_w.dim()}"
)
assert blocked_s.size(0) == Nc, (
f"Expected Nc={Nc} scale blocks, got {blocked_s.size(0)}"
)
assert blocked_zp.size(0) == Nc, (
f"Expected Nc={Nc} qzeros blocks, got {blocked_zp.size(0)}"
)
print(
f" [PASS] packing shapes K={K}, N={N}, gs={group_size}: "
f"blocked_w={list(blocked_w.shape)}, "
f"blocked_s={list(blocked_s.shape)}, blocked_zp={list(blocked_zp.shape)}"
)
class TestInt4ScaledMmCpu:
"""Tests for int4_scaled_mm_cpu GEMM kernel."""
@requires_cpu_w4a8_int8
@pytest.mark.parametrize(
"M,K,N,group_size",
[
(1, 128, 128, 128),
(4, 256, 256, 128),
(16, 512, 256, 64),
(32, 256, 512, 128),
(64, 512, 512, 128),
],
)
def test_gemm_vs_float_reference(self, M, K, N, group_size):
"""INT4 W4A8 GEMM should approximate float matmul."""
(packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
make_awq_checkpoint_data(K, N, group_size)
)
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_qweight, packed_qzeros, scales
)
x = torch.randn(M, K, dtype=torch.bfloat16)
out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, None)
ref_out = torch.mm(x.float(), float_ref)
abs_diff = (out.float() - ref_out).abs()
mean_abs = abs_diff.mean().item()
pct95 = torch.quantile(abs_diff, 0.95).item()
ref_mag = ref_out.abs().mean().item() + 1e-6
mean_rel = mean_abs / ref_mag
assert mean_rel < 0.05, (
f"Mean relative error {mean_rel:.4f} exceeds 5% threshold"
)
assert pct95 < ref_mag * 0.15, (
f"95th-pctile abs_diff {pct95:.4f} exceeds 15% of ref magnitude"
)
print(f" [PASS] INT4 GEMM correct: M={M}, K={K}, N={N}")
@requires_cpu_w4a8_int8
@pytest.mark.parametrize("M", [1, 8, 32])
def test_gemm_with_bias(self, M):
"""INT4 W4A8 GEMM with bias should match reference."""
K, N, group_size = 256, 128, 128
(packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
make_awq_checkpoint_data(K, N, group_size)
)
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_qweight, packed_qzeros, scales
)
bias = torch.randn(N, dtype=torch.float32)
x = torch.randn(M, K, dtype=torch.bfloat16)
out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, bias)
ref_out = torch.mm(x.float(), float_ref) + bias
abs_diff = (out.float() - ref_out).abs()
mean_abs = abs_diff.mean().item()
ref_mag = ref_out.abs().mean().item() + 1e-6
mean_rel = mean_abs / ref_mag
assert mean_rel < 0.05, (
f"Mean relative error {mean_rel:.4f} with bias exceeds 5%"
)
print(f" [PASS] INT4 GEMM with bias: M={M}")
@requires_cpu_w4a8_int8
def test_gemm_3d_input(self):
"""apply() reshapes 3D input [B, S, K] -> [B*S, K] -> back to 3D."""
K, N, group_size = 256, 128, 128
(packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
make_awq_checkpoint_data(K, N, group_size)
)
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_qweight, packed_qzeros, scales
)
B, S = 2, 8
x_3d = torch.randn(B, S, K, dtype=torch.bfloat16)
x_2d = x_3d.reshape(-1, K)
out_2d = torch.ops._C.int4_scaled_mm_cpu(
x_2d, blocked_w, blocked_zp, blocked_s, None
)
out_3d = out_2d.reshape(B, S, N)
ref_out = torch.mm(x_2d.float(), float_ref).reshape(B, S, N)
assert out_3d.shape == (B, S, N)
abs_diff = (out_3d.float() - ref_out).abs()
mean_abs = abs_diff.mean().item()
ref_mag = ref_out.abs().mean().item() + 1e-6
mean_rel = mean_abs / ref_mag
assert mean_rel < 0.05, f"Mean relative error {mean_rel:.4f} for 3D exceeds 5%"
print(f" [PASS] 3D input [{B},{S},{K}] -> output [{B},{S},{N}]")
@requires_cpu_w4a8_int8
def test_gemm_fp16_input(self):
"""INT4 GEMM should also work with fp16 input."""
K, N, group_size, M = 256, 256, 128, 8
(packed_qweight, packed_qzeros, scales, float_ref, _, _) = (
make_awq_checkpoint_data(K, N, group_size)
)
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_qweight, packed_qzeros, scales
)
x = torch.randn(M, K, dtype=torch.float16)
out = torch.ops._C.int4_scaled_mm_cpu(x, blocked_w, blocked_zp, blocked_s, None)
ref_out = torch.mm(x.float(), float_ref)
abs_diff = (out.float() - ref_out).abs()
ref_mag = ref_out.abs().mean().item() + 1e-6
mean_rel = abs_diff.mean().item() / ref_mag
assert mean_rel < 0.05, (
f"Mean relative error {mean_rel:.4f} for fp16 exceeds 5%"
)
print(f" [PASS] fp16 input M={M}, K={K}, N={N}")
class TestCreateWeightsUnchanged:
"""Create_weights should still produce correct int4 placeholder shapes."""
@pytest.mark.parametrize(
"K,N,group_size",
[
(128, 128, 128),
(256, 256, 128),
(512, 256, 64),
],
)
def test_int4_placeholder_shapes(self, K, N, group_size):
"""Verify qweight, qzeros, scales shapes."""
pack_factor = 8
num_groups = K // group_size
qweight = torch.empty(K, N // pack_factor, dtype=torch.int32)
qzeros = torch.empty(num_groups, N // pack_factor, dtype=torch.int32)
scales = torch.empty(num_groups, N, dtype=torch.bfloat16)
assert qweight.shape == (K, N // pack_factor)
assert qzeros.shape == (num_groups, N // pack_factor)
assert scales.shape == (num_groups, N)
print(f" [PASS] create_weights shapes: K={K}, N={N}, gs={group_size}")
...@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): ...@@ -2967,6 +2967,38 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
return torch.empty((M, N), dtype=out_dtype) return torch.empty((M, N), dtype=out_dtype)
if hasattr(torch.ops._C, "convert_weight_packed_scale_zp"):
@register_fake("_C::convert_weight_packed_scale_zp")
def convert_weight_packed_scale_zp_fake(
qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return (
torch.empty_like(qweight),
torch.empty_like(qzeros),
torch.empty_like(scales),
)
if hasattr(torch.ops._C, "int4_scaled_mm_cpu"):
@register_fake("_C::int4_scaled_mm_cpu")
def int4_scaled_mm_cpu_fake(
x: torch.Tensor,
w: torch.Tensor,
w_zeros: torch.Tensor,
w_scales: torch.Tensor,
bias: torch.Tensor | None,
) -> torch.Tensor:
N = w_scales.size(0) * w_scales.size(-1)
return torch.empty((x.size(0), N), dtype=x.dtype, device=x.device)
_supports_cpu_w4a8_int8 = bool(hasattr(torch.ops._C, "convert_weight_packed_scale_zp"))
class CPUDNNLGEMMHandler: class CPUDNNLGEMMHandler:
def __init__(self) -> None: def __init__(self) -> None:
self.handler_tensor: torch.Tensor | None = None self.handler_tensor: torch.Tensor | None = None
......
...@@ -52,6 +52,7 @@ if TYPE_CHECKING: ...@@ -52,6 +52,7 @@ if TYPE_CHECKING:
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False VLLM_CPU_SGL_KERNEL: bool = False
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_CPU_INT4_W4A8: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
...@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -728,6 +729,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool( "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1")) int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
), ),
# (CPU backend only) whether to use SGLang INT4 W4A8 kernels for AWQ.
"VLLM_CPU_INT4_W4A8": lambda: bool(int(os.getenv("VLLM_CPU_INT4_W4A8", "1"))),
# If the env var is set, Ray Compiled Graph uses the specified # If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to # channel type to communicate between workers belonging to
# different pipeline-parallel stages. # different pipeline-parallel stages.
......
...@@ -7,9 +7,8 @@ import torch ...@@ -7,9 +7,8 @@ import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm._custom_ops import ( import vllm.envs as envs
cpu_gemm_wna16, from vllm import _custom_ops as ops
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
...@@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase): ...@@ -230,7 +229,14 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
torch.set_printoptions(profile="full", linewidth=5000, sci_mode=False) layer.use_w4a8 = envs.VLLM_CPU_INT4_W4A8 and torch.cpu._is_amx_tile_supported()
if layer.use_w4a8:
self._process_weights_sglang_int4(layer)
else:
self._process_weights_woq(layer)
def _process_weights_woq(self, layer: torch.nn.Module) -> None:
"""Original WOQ int4 repack path."""
packed_weight = layer.qweight.data packed_weight = layer.qweight.data
packed_zeros = layer.qzeros.data packed_zeros = layer.qzeros.data
group_num = packed_zeros.size(0) group_num = packed_zeros.size(0)
...@@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase): ...@@ -266,8 +272,6 @@ class CPUAWQLinearMethod(LinearMethodBase):
) )
zeros = pack_cols(zeros, bits, group_num, output_size).contiguous() zeros = pack_cols(zeros, bits, group_num, output_size).contiguous()
# make 16 output channel as a block and transpose to
# the make the block contiguous
weight = pack_cols(weight, bits, input_size, output_size) weight = pack_cols(weight, bits, input_size, output_size)
weight = ( weight = (
weight.view(input_size, -1, 16 // pack_factor) weight.view(input_size, -1, 16 // pack_factor)
...@@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase): ...@@ -278,13 +282,40 @@ class CPUAWQLinearMethod(LinearMethodBase):
layer.qweight.data = weight layer.qweight.data = weight
layer.qzeros.data = zeros layer.qzeros.data = zeros
def _process_weights_sglang_int4(self, layer: torch.nn.Module) -> None:
"""SGLang INT4 W4A8 path: pack int4 weights with VNNI reordering."""
packed_weight = layer.qweight.data
packed_zeros = layer.qzeros.data
scales = layer.scales.data
blocked_w, blocked_zp, blocked_s = torch.ops._C.convert_weight_packed_scale_zp(
packed_weight, packed_zeros, scales
)
layer.packed_weight = blocked_w
layer.packed_qzeros = blocked_zp
layer.packed_scales = blocked_s
layer.qweight = None
layer.qzeros = None
layer.scales = None
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
x = cpu_gemm_wna16( if layer.use_w4a8:
return self._apply_sglang_int4(layer, x, bias)
return self._apply_woq(layer, x, bias)
def _apply_woq(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""Original WOQ int4 GEMM path."""
x = ops.cpu_gemm_wna16(
input=x, input=x,
q_weight=layer.qweight, q_weight=layer.qweight,
scales=layer.scales, scales=layer.scales,
...@@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase): ...@@ -296,6 +327,26 @@ class CPUAWQLinearMethod(LinearMethodBase):
) )
return x return x
def _apply_sglang_int4(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""SGLang INT4 W4A8 GEMM path."""
x_shape = x.shape
x_2d = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x
out = torch.ops._C.int4_scaled_mm_cpu(
x_2d,
layer.packed_weight,
layer.packed_qzeros,
layer.packed_scales,
bias,
)
out = out.reshape(x_shape[:-1] + (out.size(-1),)) if len(x_shape) > 2 else out
return out
def _get_isa_hint(dtype: torch.dtype) -> str: def _get_isa_hint(dtype: torch.dtype) -> str:
supports_amx = torch.cpu._is_amx_tile_supported() supports_amx = torch.cpu._is_amx_tile_supported()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment