Unverified commit 5aa1ebd2, authored by Peng Zhang and committed by GitHub

[2/n] decouple quantization implementation from vLLM dependency (#8112)


Co-authored-by: walker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: leoneo <1320612015@qq.com>
parent 4dbf4360
@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
-#include "gptq_marlin/marlin.cuh"
-#include "gptq_marlin/marlin_dtypes.cuh"
+#include "gemm/marlin/marlin.cuh"
+#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
......
@@ -18,13 +18,12 @@
/*
 * Adapted from https://github.com/IST-DASLab/marlin
 */
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
-#include "gptq_marlin/marlin.cuh"
-#include "gptq_marlin/marlin_dtypes.cuh"
+#include "gemm/marlin/marlin.cuh"
+#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......
@@ -23,7 +23,6 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
-#include "core/registration.h"
#include "kernel.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -50,8 +49,7 @@ __global__ void permute_cols_kernel(
    int size_m,
    int size_k,
    int top_k) {};
-}
+}  // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
    torch::Tensor& a,
......
@@ -298,6 +298,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
+static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
@@ -313,6 +314,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
+static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
......
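Editor's note: the new kFE2M1f / kFloat4_e2m1f entries describe the 4-bit E2M1 ("FP4") floating-point format, following the same ScalarType::float_(exponent, mantissa, finite_values_only, nan_repr) pattern as the neighboring kFE3M2f entry. As a rough illustration (not part of the commit), the eight magnitude codes decode as follows, assuming the conventional exponent bias of 1:

# Editor's illustration (not part of the commit): decode a 4-bit E2M1 code.
# Assumes 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit, no Inf/NaN.
def decode_e2m1(code: int) -> float:
    sign = -1.0 if (code >> 3) & 0x1 else 1.0
    exp = (code >> 1) & 0x3
    man = code & 0x1
    if exp == 0:  # subnormal codes: 0.0 or 0.5
        return sign * man * 0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

print(sorted(decode_e2m1(c) for c in range(8)))
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]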
@@ -224,6 +224,40 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const t
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
+torch::Tensor gptq_marlin_gemm(
+    torch::Tensor& a,
+    std::optional<torch::Tensor> c_or_none,
+    torch::Tensor& b_q_weight,
+    torch::Tensor& b_scales,
+    std::optional<torch::Tensor> const& global_scale_or_none,
+    std::optional<torch::Tensor> const& b_zeros_or_none,
+    std::optional<torch::Tensor> const& g_idx_or_none,
+    std::optional<torch::Tensor> const& perm_or_none,
+    torch::Tensor& workspace,
+    sglang::ScalarTypeId const& b_q_type_id,
+    int64_t size_m,
+    int64_t size_n,
+    int64_t size_k,
+    bool is_k_full,
+    bool use_atomic_add,
+    bool use_fp32_reduce,
+    bool is_zp_float);
+
+torch::Tensor gptq_gemm(
+    torch::Tensor a,
+    torch::Tensor b_q_weight,
+    torch::Tensor b_gptq_qzeros,
+    torch::Tensor b_gptq_scales,
+    torch::Tensor b_g_idx,
+    bool use_shuffle,
+    int64_t bit);
+
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
+
+torch::Tensor
+gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
+
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
 * From csrc/moe
 */
@@ -340,15 +374,6 @@ void scaled_fp4_experts_quant(
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
-namespace marlin_moe_wna16 {
-torch::Tensor
-gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
-torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
-} // namespace marlin_moe_wna16
/*
 * From csrc/speculative
 */
......
@@ -44,6 +44,9 @@ from sgl_kernel.gemm import (
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
+    gptq_gemm,
+    gptq_marlin_gemm,
+    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
......
@@ -2,6 +2,7 @@ import functools
from typing import Optional
import torch
+from sgl_kernel import silu_and_mul
def get_scalar_type(num_bits: int, has_zp: bool):
@@ -165,7 +166,7 @@ def fused_marlin_moe(
        is_zp_float=False,
    )
-    torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+    silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
    if expert_map is not None:
        intermediate_cache3.zero_()
......
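Editor's note: the change above swaps the vLLM-style torch.ops._C op (output buffer first) for sgl_kernel.silu_and_mul, which, as the new call shows, takes the concatenated gate/up input first and the output buffer second. A minimal sketch of that calling convention (illustrative only, assumes a CUDA device):

# Editor's illustration (not part of the commit): sgl_kernel.silu_and_mul argument order.
import torch
from sgl_kernel import silu_and_mul

d = 128
x = torch.randn(4, 2 * d, dtype=torch.float16, device="cuda")  # [gate | up] halves concatenated
out = torch.empty(4, d, dtype=torch.float16, device="cuda")
silu_and_mul(x, out)  # out = silu(x[:, :d]) * x[:, d:]; the replaced vLLM op took (out, x)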
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
import torch
+from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
@@ -353,3 +354,62 @@ def scaled_fp4_experts_quant(
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales
+
+
+# GPTQ kernels
+def gptq_marlin_gemm(
+    a: torch.Tensor,
+    c: Optional[torch.Tensor],
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    global_scale: Optional[torch.Tensor],
+    b_zeros: Optional[torch.Tensor],
+    g_idx: Optional[torch.Tensor],
+    perm: Optional[torch.Tensor],
+    workspace: torch.Tensor,
+    b_q_type: ScalarType,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+    is_k_full: bool = True,
+    use_atomic_add: bool = False,
+    use_fp32_reduce: bool = False,
+    is_zp_float: bool = False,
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_marlin_gemm(
+        a,
+        c,
+        b_q_weight,
+        b_scales,
+        global_scale,
+        b_zeros,
+        g_idx,
+        perm,
+        workspace,
+        b_q_type.id,
+        size_m,
+        size_n,
+        size_k,
+        is_k_full,
+        use_atomic_add,
+        use_fp32_reduce,
+        is_zp_float,
+    )
+
+
+def gptq_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_gptq_qzeros: torch.Tensor,
+    b_gptq_scales: torch.Tensor,
+    b_g_idx: torch.Tensor,
+    use_shuffle: bool,
+    bit: int,
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_gemm(
+        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
+    )
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
+    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
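Editor's note: for orientation, a hypothetical call sketch for the new wrappers, with tensor shapes taken from the 4-bit test added later in this commit. The gptq_shuffle + use_shuffle=True pairing is an assumption about intended usage; the new test only exercises use_shuffle=False.

# Editor's illustration (not part of the commit): random data, shapes only.
import torch
from sgl_kernel import gptq_gemm, gptq_shuffle

M, K, N, bit, group_size = 8, 2048, 4096, 4, 128
pack_factor = 32 // bit
num_groups = K // group_size

a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b_q_weight = torch.randint(0, 2**31 - 1, (K // pack_factor, N), dtype=torch.int32, device="cuda")
b_qzeros = torch.randint(0, 2**31 - 1, (num_groups, N // pack_factor), dtype=torch.int32, device="cuda")
b_scales = torch.rand(num_groups, N, dtype=torch.float16, device="cuda") * 0.01
g_idx = torch.arange(K, dtype=torch.int32, device="cuda") // group_size

# Direct path, as exercised by the new test:
c = gptq_gemm(a, b_q_weight, b_qzeros, b_scales, g_idx, False, bit)

# Assumed shuffled path (not covered by the test): permute once, then pass use_shuffle=True.
# gptq_shuffle(b_q_weight, g_idx, bit)
# c = gptq_gemm(a, b_q_weight, b_qzeros, b_scales, g_idx, True, bit)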
@@ -7,8 +7,8 @@ def gptq_marlin_repack(
    size_k,
    size_n,
    num_bits,
-):
-    torch.ops.sgl_kernel.gptq_marlin_repack.default(
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_marlin_repack(
        b_q_weight,
        perm,
        size_k,
......
import pytest
import torch
from sgl_kernel import gptq_gemm
from sglang.srt.layers.quantization.utils import pack_cols, pack_rows


def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
    assert bit == 4, "Reference dequantization only supports 4-bit"
    group_size = K // scales.shape[0]
    pack_factor = 32 // bit

    # unpack q_weight: (K//pack_factor, N) -> (K, N)
    unpacked_q_weight = torch.empty(
        q_weight.shape[0] * pack_factor,
        q_weight.shape[1],
        dtype=torch.uint8,
        device=q_weight.device,
    )
    for i in range(pack_factor):
        unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F

    # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N)
    unpacked_q_zeros = torch.empty(
        q_zeros.shape[0],
        q_zeros.shape[1] * pack_factor,
        dtype=torch.uint8,
        device=q_zeros.device,
    )
    for i in range(pack_factor):
        unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F
    unpacked_q_zeros += 1
    unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype)

    scale_zeros = unpacked_q_zeros * scales  # (num_groups, N)
    current_g_idx = torch.tensor(
        [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device
    )
    scale_mat = scales[current_g_idx]  # (K, N)
    scale_zeros_mat = scale_zeros[current_g_idx]  # (K, N)

    # dequant: weight * scale - scale_zeros
    dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat
    return dequantized_b.reshape(K, N)
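Editor's note: element-wise, the reference above computes w = q*s - (z_stored + 1)*s, i.e. (q - zero)*s with GPTQ's off-by-one zero-point packing. A tiny numeric check (illustrative only):

# Editor's illustration (not part of the commit): the dequant identity on one element.
q, z_stored, s = 7, 5, 0.1          # GPTQ stores the zero point minus one
w = q * s - (z_stored + 1) * s      # same as (q - zero) * s with zero = 6
assert abs(w - 0.1) < 1e-6          # (7 - 6) * 0.1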
def torch_gptq_gemm(
    a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
):
    K, N = a.shape[1], b_q_weight.shape[1]
    b_dequant = torch_dequantize(
        b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N
    )
    c = torch.matmul(a, b_dequant)
    return c


def _test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, device="cuda"):
    b_fp = torch.randn(K, N, dtype=dtype, device=device)
    assert K % group_size == 0, "K must be divisible by group_size"
    num_groups = K // group_size

    if use_shuffle:
        return
    else:
        g_idx = torch.tensor(
            [i // group_size for i in range(K)], dtype=torch.int32, device=device
        )
        b_shuffled = b_fp[g_idx]

    b_grouped = b_shuffled.reshape(num_groups, group_size, N)
    b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
    b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]

    scales = (b_max - b_min) / (2**bit - 1)
    scales = scales.clamp(min=1e-6)
    zeros_float = (-b_min / scales).round()
    q_b = (
        (b_grouped / scales + zeros_float).round().clamp(0, 2**bit - 1).to(torch.uint8)
    )
    q_zeros_unpacked = zeros_float.to(torch.uint8) - 1

    b_q_weight = pack_rows(q_b.reshape(K, N), bit, K, N)
    q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
    b_gptq_qzeros = pack_cols(q_zeros_unpacked, bit, num_groups, N)
    b_gptq_scales = scales.squeeze(1)

    a = torch.randn(M, K, dtype=dtype, device=device)

    c_ref = torch_gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )
    c_out = gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, g_idx, use_shuffle, bit
    )

    rtol = 4e-2
    atol = 4e-2
    torch.testing.assert_close(c_ref, c_out, rtol=rtol, atol=atol)
    print(
        f"✅ Test passed: M={M}, N={N}, K={K}, bit={bit}, group_size={group_size}, use_shuffle={use_shuffle}, dtype={dtype}"
    )
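Editor's note: the per-group asymmetric scheme above picks scale = (max - min) / (2**bit - 1) and zero = round(-min / scale), then stores zero - 1 the way GPTQ checkpoints do. A tiny numeric illustration (not part of the commit):

# Editor's illustration (not part of the commit): one 4-bit group, same formulas as above.
import torch

w = torch.tensor([-0.30, -0.10, 0.05, 0.45])
scale = (w.max() - w.min()) / (2**4 - 1)       # 0.05
zero = (-w.min() / scale).round()              # 6.0
q = (w / scale + zero).round().clamp(0, 15)    # tensor([ 0.,  4.,  7., 15.])
stored_zero = zero - 1                         # GPTQ packs the zero point minus one
w_hat = (q - zero) * scale
assert torch.allclose(w, w_hat, atol=1e-6)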
@pytest.mark.parametrize("M", [1, 8, 128])
@pytest.mark.parametrize("N", [2048, 4096])
@pytest.mark.parametrize("K", [2048, 4096])
@pytest.mark.parametrize("bit", [4])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("use_shuffle", [False])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gptq_gemm(M, N, K, bit, group_size, use_shuffle, dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
_test_gptq_gemm_once(M, N, K, bit, group_size, use_shuffle, dtype, "cuda")
if __name__ == "__main__":
pytest.main([__file__, "-v"])
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]


# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return
    if size_k % group_size != 0:
        return
    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        # GPTQ style, unsigned + symmetric bias
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace(w_ref.device)

    # marlin gemm
    output = gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    # ref gemm
    output_ref = torch.matmul(a_input, w_ref)
    torch.cuda.synchronize()

    max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )
    assert max_diff < 0.04


if __name__ == "__main__":
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])
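Editor's note: the quantity named max_diff above is really a mean relative error (mean absolute difference normalized by the mean absolute reference value), budgeted at 4%. Restated as a standalone helper (illustrative only):

# Editor's illustration (not part of the commit): the tolerance check, restated.
import torch

def mean_relative_error(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    return (out - ref).abs().mean() / ref.abs().mean()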
import numpy as np
import pytest
import torch
-from sgl_kernel import awq_marlin_repack
+from sgl_kernel import awq_marlin_repack, gptq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
-    get_pack_factor,
+    gptq_quantize_weights,
    pack_cols,
+    pack_rows,
    quantize_weights,
+    sort_weights,
)
+from sglang.test.test_marlin_utils import get_weight_perm, marlin_weights

GPTQ_MARLIN_TILE = 16

+MARLIN_K_CHUNKS = [128]
+MARLIN_N_CHUNKS = [64, 256]
+
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+    (257, 13, 11),
+    (658, 13, 11),
+]
def awq_pack(
@@ -35,70 +51,6 @@ def awq_pack(
    return pack_cols(q_w, num_bits, size_k, size_n)
-def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
-    assert q_w.shape == (size_k, size_n)
-    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
-    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
-
-    # Permute weights to 16x64 marlin tiles
-    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
-    q_w = q_w.permute((0, 2, 1, 3))
-    q_w = q_w.reshape((size_k // tile, size_n * tile))
-    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
-    return q_w
-
-
-def marlin_weights(q_w, size_k, size_n, num_bits, perm):
-    # Permute
-    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
-
-    # Pack
-    pack_factor = get_pack_factor(num_bits)
-    orig_device = q_w.device
-    q_w = q_w.cpu().numpy().astype(np.uint32)
-    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
-    for i in range(pack_factor):
-        q_packed |= q_w[:, i::pack_factor] << num_bits * i
-    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
-    return q_packed
-
-
-def get_weight_perm(num_bits: int):
-    perm_list: list[int] = []
-    for i in range(32):
-        perm1: list[int] = []
-        col = i // 4
-        for block in [0, 1]:
-            for row in [
-                2 * (i % 4),
-                2 * (i % 4) + 1,
-                2 * (i % 4 + 4),
-                2 * (i % 4 + 4) + 1,
-            ]:
-                perm1.append(16 * row + col + 8 * block)
-        for j in range(4):
-            perm_list.extend([p + 256 * j for p in perm1])
-
-    perm = np.array(perm_list)
-
-    if num_bits == 4:
-        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
-    elif num_bits == 8:
-        interleave = np.array([0, 2, 1, 3])
-    else:
-        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
-
-    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
-    perm = torch.from_numpy(perm)
-    return perm
@pytest.mark.parametrize("num_bits", [4, 8]) @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)]) @pytest.mark.parametrize("k_tiles,n_tiles", [(1, 1), (2, 2)])
@pytest.mark.parametrize("group_size", [16, 32]) @pytest.mark.parametrize("group_size", [16, 32])
...@@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size): ...@@ -130,6 +82,66 @@ def test_awq_marlin_repack_correct(num_bits, k_tiles, n_tiles, group_size):
torch.testing.assert_close(out_gpu, q_w_marlin) torch.testing.assert_close(out_gpu, q_w_marlin)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", [scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
if size_k % group_size != 0:
pytest.skip("size_k must be divisible by group_size")
# Create input
b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order
)
q_w_gptq = pack_rows(q_w, quant_type.size_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
marlin_layout_perm = get_weight_perm(quant_type.size_bits)
q_w_marlin_ref = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, marlin_layout_perm
)
# Run Marlin repack GPU kernel
q_w_marlin = gptq_marlin_repack(
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits
)
torch.cuda.synchronize()
torch.testing.assert_close(q_w_marlin, q_w_marlin_ref)
if __name__ == "__main__":
    import subprocess
......
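Editor's note: for reference, the shape bookkeeping implied by the repack test and by the marlin_weights reference path this commit moves into sglang.test.test_marlin_utils (illustrative, 4-bit case):

# Editor's illustration (not part of the commit): layout sizes around gptq_marlin_repack.
size_k, size_n, num_bits = 2048, 1024, 4
pack_factor = 32 // num_bits                                 # 8 codes per int32
gptq_shape = (size_k // pack_factor, size_n)                 # pack_rows output fed to the kernel
marlin_shape = (size_k // 16, size_n * 16 // pack_factor)    # what the marlin_weights reference produces
assert gptq_shape[0] * gptq_shape[1] == marlin_shape[0] * marlin_shape[1]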