Unverified Commit 82b2a84c authored by spike-zhu's avatar spike-zhu Committed by GitHub
Browse files

issue/458 add AWQ dequantization torch test and improve variable naming readability

parent 3a91947e
...@@ -21,9 +21,6 @@ __C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t de ...@@ -21,9 +21,6 @@ __C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t de
const void *qweight, const void *qweight,
const void *scales, const void *scales,
const void *zeros, const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc); __C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc);
......
...@@ -46,9 +46,6 @@ ...@@ -46,9 +46,6 @@
const void *qweight, \ const void *qweight, \
const void *scales, \ const void *scales, \
const void *zeros, \ const void *zeros, \
int split_k_iters, \
int thx, \
int thy, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
......
...@@ -11,11 +11,11 @@ class DequantizeInfo { ...@@ -11,11 +11,11 @@ class DequantizeInfo {
DequantizeInfo() = default; DequantizeInfo() = default;
public: public:
int _in_c, _qout_c, _G; int _in_features, _out_features, _num_groups;
int in_c() const { return _in_c; } int in_features() const { return _in_features; }
int qout_c() const { return _qout_c; } int out_features() const { return _out_features; }
int G() const { return _G; } int num_groups() const { return _num_groups; }
static utils::Result<DequantizeInfo> create( static utils::Result<DequantizeInfo> create(
infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t out_desc,
...@@ -23,14 +23,14 @@ public: ...@@ -23,14 +23,14 @@ public:
infiniopTensorDescriptor_t scales_desc, infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc) { infiniopTensorDescriptor_t zeros_desc) {
int _in_c = qweight_desc->dim(0); int _in_features = qweight_desc->dim(0);
int _qout_c = qweight_desc->dim(1); int _out_features = qweight_desc->dim(1);
int _G = scales_desc->dim(0); int _num_groups = scales_desc->dim(0);
return utils::Result<DequantizeInfo>(DequantizeInfo{ return utils::Result<DequantizeInfo>(DequantizeInfo{
_in_c, _in_features,
_qout_c, _out_features,
_G}); _num_groups});
} }
}; };
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) { __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false); #error "dequantize_s4_to_fp16x2 requires CUDA compute capability >= 7.5"
#else #else
uint4 result; uint4 result;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
__global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64)
dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors, dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
int *__restrict__ zeros, half *__restrict__ C, int G) { int *__restrict__ zeros, half *__restrict__ C, int group_size) {
static constexpr uint32_t ZERO = 0x0; static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)]; half B_shared[32 * (128 + 8)];
...@@ -23,9 +23,9 @@ __global__ void __launch_bounds__(64) ...@@ -23,9 +23,9 @@ __global__ void __launch_bounds__(64)
int index2 = col + row * N; int index2 = col + row * N;
int *B_ptr2 = B + index2; int *B_ptr2 = B + index2;
int index3 = col + (int)(row / G) * N; int index3 = col + (int)(row / group_size) * N;
int *zeros_ptr2 = zeros + index3; int *zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (int)(row / G) * N * 8; int index4 = 8 * col + (int)(row / group_size) * N * 8;
half *scaling_factors_ptr2 = scaling_factors + index4; half *scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2); uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
...@@ -103,32 +103,21 @@ Descriptor::calculate( ...@@ -103,32 +103,21 @@ Descriptor::calculate(
const void *qweight, const void *qweight,
const void *scales, const void *scales,
const void *zeros, const void *zeros,
int split_k_iters,
int thx,
int thy,
void *stream) const { void *stream) const {
int in_c = _info.in_c(); int in_features = _info.in_features();
int qout_c = _info.qout_c(); int out_features = _info.out_features();
int out_c = qout_c * 8; int group_size = in_features / _info.num_groups();
int G = in_c / _info.G();
// ==================== 默认配置, 固定为 8 ====================
int x_thread = thx; constexpr int BLOCK_X = 8;
int y_thread = thy; constexpr int BLOCK_Y = 8;
int x_blocks = 1; int x_blocks = (out_features + BLOCK_X - 1) / BLOCK_X;
int y_blocks = 1; int y_blocks = (in_features + BLOCK_Y - 1) / BLOCK_Y;
if (thx == 0) {
x_thread = qout_c; dim3 num_blocks(x_blocks, y_blocks);
} dim3 threads_per_block(BLOCK_X, BLOCK_Y);
if (thy == 0) { // =====================================================
y_thread = in_c;
}
if (thx == 0 && thy == 0) {
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
half *out_ = reinterpret_cast<half *>(out); half *out_ = reinterpret_cast<half *>(out);
...@@ -136,11 +125,8 @@ Descriptor::calculate( ...@@ -136,11 +125,8 @@ Descriptor::calculate(
half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales)); half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros)); int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread);
dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
qweight_, scales_, zeros_, out_, G); qweight_, scales_, zeros_, out_, group_size);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -60,15 +60,12 @@ __C infiniStatus_t infiniopDequantize( ...@@ -60,15 +60,12 @@ __C infiniStatus_t infiniopDequantize(
const void *qweight, const void *qweight,
const void *scales, const void *scales,
const void *zeros, const void *zeros,
size_t split_k_iters,
size_t thx,
size_t thy,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, qweight, scales, zeros, split_k_iters, thx, thy, stream) ->calculate(workspace, workspace_size, out, qweight, scales, zeros, stream)
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
......
...@@ -23,22 +23,112 @@ from libinfiniop import ( ...@@ -23,22 +23,112 @@ from libinfiniop import (
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride # qweight_shape, qzeros_shape, qscales_shape, out_shape, qweight_strides, qzeros_strides,
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), # qscales_strides, out_strides, qweights_dtype, qzeros_dtype, qscales_dtype, out_dtype, bits, group_size
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), (
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), (512, 256),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), (16, 256),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None), (16, 2048),
(512, 2048),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
32,
),
(
(1024, 128),
(2, 128),
(2, 1024),
(1024, 1024),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
512,
),
(
(2048, 1024),
(16, 1024),
(16, 8192),
(2048, 8192),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
128,
),
(
(4096, 512),
(4, 512),
(4, 4096),
(4096, 4096),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
1024,
),
(
(8192, 256),
(64, 256),
(64, 2048),
(8192, 2048),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
128,
),
(
(8192, 512),
(32, 512),
(32, 4096),
(8192, 4096),
None,
None,
None,
None,
InfiniDtype.I32,
InfiniDtype.I32,
InfiniDtype.F16,
InfiniDtype.F16,
4,
256,
),
] ]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] _TENSOR_DTYPES = [InfiniDtype.F16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 0, "rtol": 1e-2}, InfiniDtype.F16: {"atol": 0, "rtol": 1e-4},
InfiniDtype.F32: {"atol": 0, "rtol": 1e-3},
InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2},
} }
DEBUG = False DEBUG = False
...@@ -46,19 +136,61 @@ PROFILE = False ...@@ -46,19 +136,61 @@ PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def dequantize(
qweight: torch.Tensor,
qzeros: torch.Tensor,
qscales: torch.Tensor,
bits: int,
group_size: int,
):
shifts = torch.arange(0, 32, bits, device=qweight.device)
# Unpacking qweight columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# PyTorch implementation for matrix multiplication # Unpacking qzeros columnwise
def gemm(d, _c, beta, _a, _b, alpha): if qzeros is not None:
try: izeros = torch.bitwise_right_shift(
if _c.ndim == 2: qzeros[:, :, None], shifts[None, None, :]
torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) ).to(
elif _c.ndim == 3: torch.int8 # smallest dtype available
torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) )
izeros = izeros.view(izeros.shape[0], -1)
else: else:
raise izeros = qzeros
except Exception:
torch.matmul(_a, _b, out=d) # Reverse AWQ specific packing order - weights are packed in reverse within each 32-bit word
d.mul_(alpha).add_(_c, alpha=beta) reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
# Extract the actual quantized values by masking higher bits
iweight = torch.bitwise_and(iweights, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# Expand scaling factors and zeros to match the full weight dimensions
# Apply dequantization formula: dequantized = (quantized - zero_point) * scale
qscales = qscales.repeat_interleave(group_size, dim=0)
izeros = izeros.repeat_interleave(group_size, dim=0)
iweight = (iweight - izeros) * qscales
return iweight
# The argument list should be (lib, handle, torch_device, <param list>, dtype) # The argument list should be (lib, handle, torch_device, <param list>, dtype)
...@@ -66,29 +198,52 @@ def gemm(d, _c, beta, _a, _b, alpha): ...@@ -66,29 +198,52 @@ def gemm(d, _c, beta, _a, _b, alpha):
def test( def test(
handle, handle,
device, device,
alpha, qweights_shape,
beta, qzeros_shape,
a_shape, qscales_shape,
b_shape, out_shape,
c_shape, qweights_stride,
a_stride=None, qzeros_stride,
b_stride=None, qscales_stride,
c_stride=None, out_stride,
dtype=InfiniDtype.F16, qweights_dtype,
qzeros_dtype,
qscales_dtype,
out_dtype,
bits,
group_size,
dtype=None,
sync=None, sync=None,
): ):
print( print(
f"Testing Gemm on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta}," f"Testing Dequantize on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size},"
f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape},"
f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}" f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride},"
f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}"
) )
qweight = TestTensor((8192, 256), None, InfiniDtype.I32, device, mode="randint") qweights = TestTensor(
scales = TestTensor((64, 2048), None, InfiniDtype.F16, device) qweights_shape, qweights_stride, qweights_dtype, device, mode="randint"
zeros = TestTensor((64, 256), None, InfiniDtype.I32, device, mode="zeros") )
out = TestTensor((8192, 2048), None, InfiniDtype.F16, device, mode="zeros") qzeros = TestTensor(qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint")
qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device)
out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros")
ans = TestTensor(out_shape, out_stride, out_dtype, device, mode="ones")
# Compute the PyTorch reference result
def torch_dequantize():
return dequantize(
qweights.torch_tensor(),
qzeros.torch_tensor(),
qscales.torch_tensor(),
bits,
group_size,
)
print(out.actual_tensor()) ans = torch_dequantize()
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t() descriptor = infiniopOperatorDescriptor_t()
check_error( check_error(
...@@ -96,15 +251,15 @@ def test( ...@@ -96,15 +251,15 @@ def test(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
out.descriptor, out.descriptor,
qweight.descriptor, qweights.descriptor,
scales.descriptor, qscales.descriptor,
zeros.descriptor, qzeros.descriptor,
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# for tensor in [a, b, c]: for tensor in [qweights, qzeros, qscales, out]:
# tensor.destroy_desc() tensor.destroy_desc()
# Get workspace size and create workspace # Get workspace size and create workspace
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
...@@ -123,35 +278,30 @@ def test( ...@@ -123,35 +278,30 @@ def test(
workspace.data(), workspace.data(),
workspace_size.value, workspace_size.value,
out.data(), out.data(),
qweight.data(), qweights.data(),
scales.data(), qscales.data(),
zeros.data(), qzeros.data(),
0,
0,
0,
None, None,
) )
) )
lib_dequantize() lib_dequantize()
print(out.actual_tensor()) # Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# # Validate results
# atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# if DEBUG: if DEBUG:
# debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# assert torch.allclose(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
# # Profiling workflow # Profiling workflow
# if PROFILE: if PROFILE:
# # fmt: off # fmt: off
# profile_operation("PyTorch", lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: torch_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
# profile_operation(" lib", lambda: lib_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS)
# # fmt: on # fmt: on
# check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor)) check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor))
# ============================================================================== # ==============================================================================
......
...@@ -555,9 +555,6 @@ def dequantize_(lib): ...@@ -555,9 +555,6 @@ def dequantize_(lib):
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_size_t,
c_size_t,
c_size_t,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyDequantizeDescriptor.restype = c_int32 lib.infiniopDestroyDequantizeDescriptor.restype = c_int32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment