Unverified Commit 82b2a84c authored by spike-zhu's avatar spike-zhu Committed by GitHub
Browse files

issue/458 add AWQ dequantization torch test and improve variable naming readability

parent 3a91947e
......@@ -21,9 +21,6 @@ __C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t de
const void *qweight,
const void *scales,
const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream);
__C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc);
......
......@@ -46,9 +46,6 @@
const void *qweight, \
const void *scales, \
const void *zeros, \
int split_k_iters, \
int thx, \
int thy, \
void *stream) const; \
}; \
}
......
......@@ -11,11 +11,11 @@ class DequantizeInfo {
DequantizeInfo() = default;
public:
int _in_c, _qout_c, _G;
int _in_features, _out_features, _num_groups;
int in_c() const { return _in_c; }
int qout_c() const { return _qout_c; }
int G() const { return _G; }
int in_features() const { return _in_features; }
int out_features() const { return _out_features; }
int num_groups() const { return _num_groups; }
static utils::Result<DequantizeInfo> create(
infiniopTensorDescriptor_t out_desc,
......@@ -23,14 +23,14 @@ public:
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc) {
int _in_c = qweight_desc->dim(0);
int _qout_c = qweight_desc->dim(1);
int _G = scales_desc->dim(0);
int _in_features = qweight_desc->dim(0);
int _out_features = qweight_desc->dim(1);
int _num_groups = scales_desc->dim(0);
return utils::Result<DequantizeInfo>(DequantizeInfo{
_in_c,
_qout_c,
_G});
_in_features,
_out_features,
_num_groups});
}
};
......
......@@ -2,7 +2,7 @@
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#error "dequantize_s4_to_fp16x2 requires CUDA compute capability >= 7.5"
#else
uint4 result;
......
......@@ -8,7 +8,7 @@
__global__ void __launch_bounds__(64)
dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
int *__restrict__ zeros, half *__restrict__ C, int G) {
int *__restrict__ zeros, half *__restrict__ C, int group_size) {
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];
......@@ -23,9 +23,9 @@ __global__ void __launch_bounds__(64)
int index2 = col + row * N;
int *B_ptr2 = B + index2;
int index3 = col + (int)(row / G) * N;
int index3 = col + (int)(row / group_size) * N;
int *zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (int)(row / G) * N * 8;
int index4 = 8 * col + (int)(row / group_size) * N * 8;
half *scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
......@@ -103,32 +103,21 @@ Descriptor::calculate(
const void *qweight,
const void *scales,
const void *zeros,
int split_k_iters,
int thx,
int thy,
void *stream) const {
int in_c = _info.in_c();
int qout_c = _info.qout_c();
int out_c = qout_c * 8;
int G = in_c / _info.G();
int x_thread = thx;
int y_thread = thy;
int x_blocks = 1;
int y_blocks = 1;
if (thx == 0) {
x_thread = qout_c;
}
if (thy == 0) {
y_thread = in_c;
}
if (thx == 0 && thy == 0) {
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
int in_features = _info.in_features();
int out_features = _info.out_features();
int group_size = in_features / _info.num_groups();
// ==================== 默认配置, 固定为 8 ====================
constexpr int BLOCK_X = 8;
constexpr int BLOCK_Y = 8;
int x_blocks = (out_features + BLOCK_X - 1) / BLOCK_X;
int y_blocks = (in_features + BLOCK_Y - 1) / BLOCK_Y;
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(BLOCK_X, BLOCK_Y);
// =====================================================
half *out_ = reinterpret_cast<half *>(out);
......@@ -136,11 +125,8 @@ Descriptor::calculate(
half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread);
dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
qweight_, scales_, zeros_, out_, G);
qweight_, scales_, zeros_, out_, group_size);
return INFINI_STATUS_SUCCESS;
}
......
......@@ -60,15 +60,12 @@ __C infiniStatus_t infiniopDequantize(
const void *qweight,
const void *scales,
const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, qweight, scales, zeros, split_k_iters, thx, thy, stream)
->calculate(workspace, workspace_size, out, qweight, scales, zeros, stream)
switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
......
......@@ -23,22 +23,112 @@ from libinfiniop import (
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
# qweight_shape, qzeros_shape, qscales_shape, out_shape, qweight_strides, qzeros_strides,
# qscales_strides, out_strides, qweights_dtype, qzeros_dtype, qscales_dtype, out_dtype, bits, group_size
(
(512, 256),
(16, 256),
(16, 2048),
(512, 2048),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
32,
),
(
(1024, 128),
(2, 128),
(2, 1024),
(1024, 1024),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
512,
),
(
(2048, 1024),
(16, 1024),
(16, 8192),
(2048, 8192),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
128,
),
(
(4096, 512),
(4, 512),
(4, 4096),
(4096, 4096),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
1024,
),
(
(8192, 256),
(64, 256),
(64, 2048),
(8192, 2048),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
128,
),
(
(8192, 512),
(32, 512),
(32, 4096),
(8192, 4096),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
256,
),
]
# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32]
_TENSOR_DTYPES = [InfiniDtype.F16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 0, "rtol": 1e-2},
InfiniDtype.F32: {"atol": 0, "rtol": 1e-3},
InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2},
InfiniDtype.F16: {"atol": 0, "rtol": 1e-4},
}
DEBUG = False
......@@ -46,19 +136,61 @@ PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
# PyTorch implementation for matrix multiplication
def gemm(d, _c, beta, _a, _b, alpha):
    """Write ``alpha * (_a @ _b) + beta * _c`` into ``d`` using PyTorch.

    A 2-D ``_c`` is handled by ``torch.addmm`` and a 3-D ``_c`` by
    ``torch.baddbmm``. Any other rank, or any failure of the fused call,
    falls back to ``torch.matmul`` followed by an in-place scale and add.

    Args:
        d: Output tensor; it is overwritten in place.
        _c: Tensor that is scaled by ``beta`` and added to the product.
        beta: Multiplier applied to ``_c``.
        _a: Left operand of the matrix product.
        _b: Right operand of the matrix product.
        alpha: Multiplier applied to ``_a @ _b``.
    """
    try:
        if _c.ndim == 2:
            torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
            return
        if _c.ndim == 3:
            torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
            return
    except Exception:
        # Deliberate best effort: if the fused op fails for any reason,
        # use the generic path below instead of propagating the error.
        pass

    # Generic path. Previously this was reached for unsupported ranks through
    # a bare `raise` (which raises RuntimeError); now it is reached directly.
    torch.matmul(_a, _b, out=d)
    d.mul_(alpha).add_(_c, alpha=beta)
def dequantize(
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    qscales: torch.Tensor,
    bits: int,
    group_size: int,
):
    """Reference dequantization of packed weights using plain PyTorch ops.

    Every element of ``qweight`` (and ``qzeros``) is expanded into
    ``32 // bits`` values by right-shifting with steps of ``bits``. The
    unpacked columns are then reordered with ``AWQ_REVERSE_ORDER``, masked to
    their low ``bits`` bits, and combined as
    ``(weight - zero_point) * scale``.

    Args:
        qweight: Packed integer weights, indexed as a 2-D tensor.
        qzeros: Packed zero points, indexed as a 2-D tensor, or ``None``.
            When ``None`` no zero point is subtracted.
        qscales: Scales; each row is repeated ``group_size`` times along
            dim 0 before the multiply.
        bits: Bit width of one packed value. ``32 // bits`` must equal
            ``len(AWQ_REVERSE_ORDER)`` for the reordering to be valid.
        group_size: Number of consecutive rows that share one row of
            ``qscales`` / ``qzeros``.

    Returns:
        The dequantized tensor ``(weight - zero_point) * scale``.
    """
    low_bits_mask = (2**bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device)

    def _unpack(packed):
        # Expand every packed word into its sub-values along a new last axis,
        # then flatten that axis into the column dimension.
        expanded = torch.bitwise_right_shift(
            packed[:, :, None], shifts[None, None, :]
        ).to(
            torch.int8  # smallest dtype available
        )
        return expanded.view(expanded.shape[0], -1)

    iweights = _unpack(qweight)
    izeros = _unpack(qzeros) if qzeros is not None else None

    # Column permutation that undoes the AWQ packing order inside each word.
    # The device comes from the weights so this also works without zeros
    # (the previous code read `izeros.device` and crashed on None).
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        device=iweights.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    # Keep only the low `bits` bits; the shift leaves higher bits in place.
    iweight = torch.bitwise_and(iweights[:, reverse_order_tensor], low_bits_mask)

    # Expand the per-group scales to one row per weight row.
    qscales = qscales.repeat_interleave(group_size, dim=0)

    if izeros is not None:
        izeros = torch.bitwise_and(izeros[:, reverse_order_tensor], low_bits_mask)
        izeros = izeros.repeat_interleave(group_size, dim=0)
        iweight = iweight - izeros

    return iweight * qscales
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
......@@ -66,29 +198,52 @@ def gemm(d, _c, beta, _a, _b, alpha):
def test(
handle,
device,
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride=None,
b_stride=None,
c_stride=None,
dtype=InfiniDtype.F16,
qweights_shape,
qzeros_shape,
qscales_shape,
out_shape,
qweights_stride,
qzeros_stride,
qscales_stride,
out_stride,
qweights_dtype,
qzeros_dtype,
qscales_dtype,
out_dtype,
bits,
group_size,
dtype=None,
sync=None,
):
print(
f"Testing Gemm on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta},"
f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}"
f"Testing Dequantize on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size},"
f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape},"
f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride},"
f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}"
)
qweights = TestTensor(
qweights_shape, qweights_stride, qweights_dtype, device, mode="randint"
)
qweight = TestTensor((8192, 256), None, InfiniDtype.I32, device, mode="randint")
scales = TestTensor((64, 2048), None, InfiniDtype.F16, device)
zeros = TestTensor((64, 256), None, InfiniDtype.I32, device, mode="zeros")
out = TestTensor((8192, 2048), None, InfiniDtype.F16, device, mode="zeros")
print(out.actual_tensor())
qzeros = TestTensor(qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint")
qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device)
out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros")
ans = TestTensor(out_shape, out_stride, out_dtype, device, mode="ones")
# Compute the PyTorch reference result
def torch_dequantize():
return dequantize(
qweights.torch_tensor(),
qzeros.torch_tensor(),
qscales.torch_tensor(),
bits,
group_size,
)
ans = torch_dequantize()
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
......@@ -96,15 +251,15 @@ def test(
handle,
ctypes.byref(descriptor),
out.descriptor,
qweight.descriptor,
scales.descriptor,
zeros.descriptor,
qweights.descriptor,
qscales.descriptor,
qzeros.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# for tensor in [a, b, c]:
# tensor.destroy_desc()
for tensor in [qweights, qzeros, qscales, out]:
tensor.destroy_desc()
# Get workspace size and create workspace
workspace_size = c_uint64(0)
......@@ -123,35 +278,30 @@ def test(
workspace.data(),
workspace_size.value,
out.data(),
qweight.data(),
scales.data(),
zeros.data(),
0,
0,
0,
qweights.data(),
qscales.data(),
qzeros.data(),
None,
)
)
lib_dequantize()
print(out.actual_tensor())
# # Validate results
# atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# if DEBUG:
# debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
if DEBUG:
debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# # Profiling workflow
# if PROFILE:
# # fmt: off
# profile_operation("PyTorch", lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
# profile_operation(" lib", lambda: lib_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
# # fmt: on
# check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor))
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: torch_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor))
# ==============================================================================
......
......@@ -555,9 +555,6 @@ def dequantize_(lib):
c_void_p,
c_void_p,
c_void_p,
c_size_t,
c_size_t,
c_size_t,
c_void_p,
]
lib.infiniopDestroyDequantizeDescriptor.restype = c_int32
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment