Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
82b2a84c
Unverified
Commit
82b2a84c
authored
Sep 18, 2025
by
spike-zhu
Committed by
GitHub
Sep 18, 2025
Browse files
issue/458 add AWQ dequantization torch test and improve variable naming readability
parent
3a91947e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
246 additions
and
122 deletions
+246
-122
include/infiniop/ops/dequantize.h
include/infiniop/ops/dequantize.h
+0
-3
src/infiniop/ops/dequantize/dequantize.h
src/infiniop/ops/dequantize/dequantize.h
+0
-3
src/infiniop/ops/dequantize/info.h
src/infiniop/ops/dequantize/info.h
+10
-10
src/infiniop/ops/dequantize/nvidia/dequantize_w42f16_kernel.cuh
...finiop/ops/dequantize/nvidia/dequantize_w42f16_kernel.cuh
+1
-1
src/infiniop/ops/dequantize/nvidia/dequantize_w42f16_nvidia.cu
...nfiniop/ops/dequantize/nvidia/dequantize_w42f16_nvidia.cu
+18
-32
src/infiniop/ops/dequantize/operator.cc
src/infiniop/ops/dequantize/operator.cc
+1
-4
test/infiniop/dequantize.py
test/infiniop/dequantize.py
+216
-66
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+0
-3
No files found.
include/infiniop/ops/dequantize.h
View file @
82b2a84c
...
...
@@ -21,9 +21,6 @@ __C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t de
const
void
*
qweight
,
const
void
*
scales
,
const
void
*
zeros
,
size_t
split_k_iters
,
size_t
thx
,
size_t
thy
,
void
*
stream
);
__C
__export
infiniStatus_t
infiniopDestroyDequantizeDescriptor
(
infiniopDequantizeDescriptor_t
desc
);
...
...
src/infiniop/ops/dequantize/dequantize.h
View file @
82b2a84c
...
...
@@ -46,9 +46,6 @@
const void *qweight, \
const void *scales, \
const void *zeros, \
int split_k_iters, \
int thx, \
int thy, \
void *stream) const; \
}; \
}
...
...
src/infiniop/ops/dequantize/info.h
View file @
82b2a84c
...
...
@@ -11,11 +11,11 @@ class DequantizeInfo {
DequantizeInfo
()
=
default
;
public:
int
_in_
c
,
_
q
out_
c
,
_G
;
int
_in_
features
,
_out_
features
,
_num_groups
;
int
in_
c
()
const
{
return
_in_
c
;
}
int
q
out_
c
()
const
{
return
_
q
out_
c
;
}
int
G
()
const
{
return
_
G
;
}
int
in_
features
()
const
{
return
_in_
features
;
}
int
out_
features
()
const
{
return
_out_
features
;
}
int
num_groups
()
const
{
return
_
num_groups
;
}
static
utils
::
Result
<
DequantizeInfo
>
create
(
infiniopTensorDescriptor_t
out_desc
,
...
...
@@ -23,14 +23,14 @@ public:
infiniopTensorDescriptor_t
scales_desc
,
infiniopTensorDescriptor_t
zeros_desc
)
{
int
_in_
c
=
qweight_desc
->
dim
(
0
);
int
_
q
out_
c
=
qweight_desc
->
dim
(
1
);
int
_
G
=
scales_desc
->
dim
(
0
);
int
_in_
features
=
qweight_desc
->
dim
(
0
);
int
_out_
features
=
qweight_desc
->
dim
(
1
);
int
_
num_groups
=
scales_desc
->
dim
(
0
);
return
utils
::
Result
<
DequantizeInfo
>
(
DequantizeInfo
{
_in_
c
,
_
q
out_
c
,
_
G
});
_in_
features
,
_out_
features
,
_
num_groups
});
}
};
...
...
src/infiniop/ops/dequantize/nvidia/dequantize_w42f16_kernel.cuh
View file @
82b2a84c
...
...
@@ -2,7 +2,7 @@
__device__
uint4
dequantize_s4_to_fp16x2
(
uint32_t
const
&
source
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert
(
false
);
#error "dequantize_s4_to_fp16x2 requires CUDA compute capability >= 7.5"
#else
uint4
result
;
...
...
src/infiniop/ops/dequantize/nvidia/dequantize_w42f16_nvidia.cu
View file @
82b2a84c
...
...
@@ -8,7 +8,7 @@
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
group_size
)
{
static
constexpr
uint32_t
ZERO
=
0x0
;
half
B_shared
[
32
*
(
128
+
8
)];
...
...
@@ -23,9 +23,9 @@ __global__ void __launch_bounds__(64)
int
index2
=
col
+
row
*
N
;
int
*
B_ptr2
=
B
+
index2
;
int
index3
=
col
+
(
int
)(
row
/
G
)
*
N
;
int
index3
=
col
+
(
int
)(
row
/
group_size
)
*
N
;
int
*
zeros_ptr2
=
zeros
+
index3
;
int
index4
=
8
*
col
+
(
int
)(
row
/
G
)
*
N
*
8
;
int
index4
=
8
*
col
+
(
int
)(
row
/
group_size
)
*
N
*
8
;
half
*
scaling_factors_ptr2
=
scaling_factors
+
index4
;
uint32_t
zeros_loaded
=
*
(
uint32_t
*
)(
zeros_ptr2
);
...
...
@@ -103,32 +103,21 @@ Descriptor::calculate(
const
void
*
qweight
,
const
void
*
scales
,
const
void
*
zeros
,
int
split_k_iters
,
int
thx
,
int
thy
,
void
*
stream
)
const
{
int
in_c
=
_info
.
in_c
();
int
qout_c
=
_info
.
qout_c
();
int
out_c
=
qout_c
*
8
;
int
G
=
in_c
/
_info
.
G
();
int
x_thread
=
thx
;
int
y_thread
=
thy
;
int
x_blocks
=
1
;
int
y_blocks
=
1
;
if
(
thx
==
0
)
{
x_thread
=
qout_c
;
}
if
(
thy
==
0
)
{
y_thread
=
in_c
;
}
if
(
thx
==
0
&&
thy
==
0
)
{
x_thread
=
8
;
y_thread
=
8
;
x_blocks
=
(
int
)(
qout_c
/
8
);
y_blocks
=
(
int
)(
in_c
/
8
);
}
int
in_features
=
_info
.
in_features
();
int
out_features
=
_info
.
out_features
();
int
group_size
=
in_features
/
_info
.
num_groups
();
// ==================== 默认配置, 固定为 8 ====================
constexpr
int
BLOCK_X
=
8
;
constexpr
int
BLOCK_Y
=
8
;
int
x_blocks
=
(
out_features
+
BLOCK_X
-
1
)
/
BLOCK_X
;
int
y_blocks
=
(
in_features
+
BLOCK_Y
-
1
)
/
BLOCK_Y
;
dim3
num_blocks
(
x_blocks
,
y_blocks
);
dim3
threads_per_block
(
BLOCK_X
,
BLOCK_Y
);
// =====================================================
half
*
out_
=
reinterpret_cast
<
half
*>
(
out
);
...
...
@@ -136,11 +125,8 @@ Descriptor::calculate(
half
*
scales_
=
const_cast
<
half
*>
(
reinterpret_cast
<
const
half
*>
(
scales
));
int
*
zeros_
=
const_cast
<
int
*>
(
reinterpret_cast
<
const
int
*>
(
zeros
));
dim3
num_blocks
(
x_blocks
,
y_blocks
);
dim3
threads_per_block
(
x_thread
,
y_thread
);
dequantize_weights
<<<
num_blocks
,
threads_per_block
,
0
,
reinterpret_cast
<
cudaStream_t
>
(
stream
)
>>>
(
qweight_
,
scales_
,
zeros_
,
out_
,
G
);
qweight_
,
scales_
,
zeros_
,
out_
,
group_size
);
return
INFINI_STATUS_SUCCESS
;
}
...
...
src/infiniop/ops/dequantize/operator.cc
View file @
82b2a84c
...
...
@@ -60,15 +60,12 @@ __C infiniStatus_t infiniopDequantize(
const
void
*
qweight
,
const
void
*
scales
,
const
void
*
zeros
,
size_t
split_k_iters
,
size_t
thx
,
size_t
thy
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, qweight, scales, zeros,
split_k_iters, thx, thy,
stream)
->calculate(workspace, workspace_size, out, qweight, scales, zeros, stream)
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
...
...
test/infiniop/dequantize.py
View file @
82b2a84c
...
...
@@ -23,22 +23,112 @@ from libinfiniop import (
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(
1.0
,
0.0
,
(
1
,
2048
),
(
2048
,
2048
),
(
1
,
2048
),
None
,
None
,
None
),
(
1.0
,
0.0
,
(
2
,
4
,
2048
),
(
2
,
2048
,
2048
),
(
2
,
4
,
2048
),
None
,
None
,
None
),
(
1.0
,
0.0
,
(
1
,
2048
),
(
2048
,
2048
),
(
1
,
2048
),
(
4096
,
1
),
(
4096
,
1
),
(
4096
,
1
)),
(
1.0
,
1.0
,
(
6
,
2048
),
(
2048
,
2560
),
(
6
,
2560
),
(
2048
,
1
),
(
1
,
2048
),
(
2560
,
1
)),
(
1.0
/
8.0
,
0.0
,
(
4
,
8
*
6
,
64
),
(
4
,
64
,
6
),
(
4
,
8
*
6
,
6
),
None
,
None
,
None
),
# qweight_shape, qzeros_shape, qscales_shape, out_shape, qweight_strides, qzeros_strides,
# qscales_strides, out_strides, qweights_dtype, qzeros_dtype, qscales_dtype, out_dtype, bits, group_size
(
(
512
,
256
),
(
16
,
256
),
(
16
,
2048
),
(
512
,
2048
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
32
,
),
(
(
1024
,
128
),
(
2
,
128
),
(
2
,
1024
),
(
1024
,
1024
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
512
,
),
(
(
2048
,
1024
),
(
16
,
1024
),
(
16
,
8192
),
(
2048
,
8192
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
128
,
),
(
(
4096
,
512
),
(
4
,
512
),
(
4
,
4096
),
(
4096
,
4096
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
1024
,
),
(
(
8192
,
256
),
(
64
,
256
),
(
64
,
2048
),
(
8192
,
2048
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
128
,
),
(
(
8192
,
512
),
(
32
,
512
),
(
32
,
4096
),
(
8192
,
4096
),
None
,
None
,
None
,
None
,
InfiniDtype
.
I32
,
InfiniDtype
.
I32
,
InfiniDtype
.
F16
,
InfiniDtype
.
F16
,
4
,
256
,
),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
InfiniDtype
.
F16
,
InfiniDtype
.
BF16
,
InfiniDtype
.
F32
]
_TENSOR_DTYPES
=
[
InfiniDtype
.
F16
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
InfiniDtype
.
F16
:
{
"atol"
:
0
,
"rtol"
:
1e-2
},
InfiniDtype
.
F32
:
{
"atol"
:
0
,
"rtol"
:
1e-3
},
InfiniDtype
.
BF16
:
{
"atol"
:
0
,
"rtol"
:
5e-2
},
InfiniDtype
.
F16
:
{
"atol"
:
0
,
"rtol"
:
1e-4
},
}
DEBUG
=
False
...
...
@@ -46,19 +136,61 @@ PROFILE = False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
AWQ_ORDER
=
[
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
]
AWQ_REVERSE_ORDER
=
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]
# PyTorch implementation for matrix multiplication
def
gemm
(
d
,
_c
,
beta
,
_a
,
_b
,
alpha
):
try
:
if
_c
.
ndim
==
2
:
torch
.
addmm
(
_c
,
_a
,
_b
,
beta
=
beta
,
alpha
=
alpha
,
out
=
d
)
elif
_c
.
ndim
==
3
:
torch
.
baddbmm
(
_c
,
_a
,
_b
,
beta
=
beta
,
alpha
=
alpha
,
out
=
d
)
else
:
raise
except
Exception
:
torch
.
matmul
(
_a
,
_b
,
out
=
d
)
d
.
mul_
(
alpha
).
add_
(
_c
,
alpha
=
beta
)
def
dequantize
(
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
qscales
:
torch
.
Tensor
,
bits
:
int
,
group_size
:
int
,
):
shifts
=
torch
.
arange
(
0
,
32
,
bits
,
device
=
qweight
.
device
)
# Unpacking qweight columnwise
iweights
=
torch
.
bitwise_right_shift
(
qweight
[:,
:,
None
],
shifts
[
None
,
None
,
:]).
to
(
torch
.
int8
# smallest dtype available
)
iweights
=
iweights
.
view
(
iweights
.
shape
[
0
],
-
1
)
# Unpacking qzeros columnwise
if
qzeros
is
not
None
:
izeros
=
torch
.
bitwise_right_shift
(
qzeros
[:,
:,
None
],
shifts
[
None
,
None
,
:]
).
to
(
torch
.
int8
# smallest dtype available
)
izeros
=
izeros
.
view
(
izeros
.
shape
[
0
],
-
1
)
else
:
izeros
=
qzeros
# Reverse AWQ specific packing order - weights are packed in reverse within each 32-bit word
reverse_order_tensor
=
torch
.
arange
(
iweights
.
shape
[
-
1
],
dtype
=
torch
.
int32
,
device
=
izeros
.
device
,
)
reverse_order_tensor
=
reverse_order_tensor
.
view
(
-
1
,
32
//
bits
)
reverse_order_tensor
=
reverse_order_tensor
[:,
AWQ_REVERSE_ORDER
]
reverse_order_tensor
=
reverse_order_tensor
.
view
(
-
1
)
if
izeros
is
not
None
:
izeros
=
izeros
[:,
reverse_order_tensor
]
iweights
=
iweights
[:,
reverse_order_tensor
]
# Extract the actual quantized values by masking higher bits
iweight
=
torch
.
bitwise_and
(
iweights
,
(
2
**
bits
)
-
1
)
izeros
=
torch
.
bitwise_and
(
izeros
,
(
2
**
bits
)
-
1
)
# Expand scaling factors and zeros to match the full weight dimensions
# Apply dequantization formula: dequantized = (quantized - zero_point) * scale
qscales
=
qscales
.
repeat_interleave
(
group_size
,
dim
=
0
)
izeros
=
izeros
.
repeat_interleave
(
group_size
,
dim
=
0
)
iweight
=
(
iweight
-
izeros
)
*
qscales
return
iweight
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
...
...
@@ -66,29 +198,52 @@ def gemm(d, _c, beta, _a, _b, alpha):
def
test
(
handle
,
device
,
alpha
,
beta
,
a_shape
,
b_shape
,
c_shape
,
a_stride
=
None
,
b_stride
=
None
,
c_stride
=
None
,
dtype
=
InfiniDtype
.
F16
,
qweights_shape
,
qzeros_shape
,
qscales_shape
,
out_shape
,
qweights_stride
,
qzeros_stride
,
qscales_stride
,
out_stride
,
qweights_dtype
,
qzeros_dtype
,
qscales_dtype
,
out_dtype
,
bits
,
group_size
,
dtype
=
None
,
sync
=
None
,
):
print
(
f
"Testing Gemm on
{
InfiniDeviceNames
[
device
]
}
with alpha:
{
alpha
}
, beta:
{
beta
}
,"
f
" a_shape:
{
a_shape
}
, b_shape:
{
b_shape
}
, c_shape:
{
c_shape
}
,"
f
" a_stride:
{
a_stride
}
, b_stride:
{
b_stride
}
, c_stride:
{
c_stride
}
, dtype:
{
InfiniDtypeNames
[
dtype
]
}
"
f
"Testing Dequantize on
{
InfiniDeviceNames
[
device
]
}
with bits:
{
bits
}
, group_size:
{
group_size
}
,"
f
" qweights_shape:
{
qweights_shape
}
, qzeros_shape:
{
qzeros_shape
}
, qscales_shape:
{
qscales_shape
}
,"
f
" qweights_stride:
{
qweights_stride
}
, qzeros_stride:
{
qzeros_stride
}
, qscales_stride:
{
qscales_stride
}
,"
f
" qweights_dtype:
{
InfiniDtypeNames
[
qweights_dtype
]
}
, qzeros_dtype:
{
InfiniDtypeNames
[
qzeros_dtype
]
}
, qscales_dtype:
{
InfiniDtypeNames
[
qscales_dtype
]
}
"
)
qweights
=
TestTensor
(
qweights_shape
,
qweights_stride
,
qweights_dtype
,
device
,
mode
=
"randint"
)
qweight
=
TestTensor
((
8192
,
256
),
None
,
InfiniDtype
.
I32
,
device
,
mode
=
"randint"
)
scales
=
TestTensor
((
64
,
2048
),
None
,
InfiniDtype
.
F16
,
device
)
zeros
=
TestTensor
((
64
,
256
),
None
,
InfiniDtype
.
I32
,
device
,
mode
=
"zeros"
)
out
=
TestTensor
((
8192
,
2048
),
None
,
InfiniDtype
.
F16
,
device
,
mode
=
"zeros"
)
print
(
out
.
actual_tensor
())
qzeros
=
TestTensor
(
qzeros_shape
,
qzeros_stride
,
qzeros_dtype
,
device
,
mode
=
"randint"
)
qscales
=
TestTensor
(
qscales_shape
,
qscales_stride
,
qscales_dtype
,
device
)
out
=
TestTensor
(
out_shape
,
out_stride
,
out_dtype
,
device
,
mode
=
"zeros"
)
ans
=
TestTensor
(
out_shape
,
out_stride
,
out_dtype
,
device
,
mode
=
"ones"
)
# Compute the PyTorch reference result
def
torch_dequantize
():
return
dequantize
(
qweights
.
torch_tensor
(),
qzeros
.
torch_tensor
(),
qscales
.
torch_tensor
(),
bits
,
group_size
,
)
ans
=
torch_dequantize
()
if
sync
is
not
None
:
sync
()
descriptor
=
infiniopOperatorDescriptor_t
()
check_error
(
...
...
@@ -96,15 +251,15 @@ def test(
handle
,
ctypes
.
byref
(
descriptor
),
out
.
descriptor
,
qweight
.
descriptor
,
scales
.
descriptor
,
zeros
.
descriptor
,
qweight
s
.
descriptor
,
q
scales
.
descriptor
,
q
zeros
.
descriptor
,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
#
for tensor in [
a, b, c
]:
#
tensor.destroy_desc()
for
tensor
in
[
qweights
,
qzeros
,
qscales
,
out
]:
tensor
.
destroy_desc
()
# Get workspace size and create workspace
workspace_size
=
c_uint64
(
0
)
...
...
@@ -123,35 +278,30 @@ def test(
workspace
.
data
(),
workspace_size
.
value
,
out
.
data
(),
qweight
.
data
(),
scales
.
data
(),
zeros
.
data
(),
0
,
0
,
0
,
qweights
.
data
(),
qscales
.
data
(),
qzeros
.
data
(),
None
,
)
)
lib_dequantize
()
print
(
out
.
actual_tensor
())
#
#
Validate results
#
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Validate results
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
#
if DEBUG:
#
debug(
c
.actual_tensor(), ans
.torch_tensor()
, atol=atol, rtol=rtol)
if
DEBUG
:
debug
(
out
.
actual_tensor
(),
ans
,
atol
=
atol
,
rtol
=
rtol
)
#
assert torch.allclose(
c
.actual_tensor(), ans
.torch_tensor()
, atol=atol, rtol=rtol)
assert
torch
.
allclose
(
out
.
actual_tensor
(),
ans
,
atol
=
atol
,
rtol
=
rtol
)
#
#
Profiling workflow
#
if PROFILE:
#
# fmt: off
#
profile_operation("PyTorch", lambda: torch_
gemm
(), device, NUM_PRERUN, NUM_ITERATIONS)
#
profile_operation(" lib", lambda: lib_
gemm
(), device, NUM_PRERUN, NUM_ITERATIONS)
#
# fmt: on
#
check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor))
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
torch_
dequantize
(),
device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_
dequantize
(),
device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
LIBINFINIOP
.
infiniopDestroyDequantizeDescriptor
(
descriptor
))
# ==============================================================================
...
...
test/infiniop/libinfiniop/op_register.py
View file @
82b2a84c
...
...
@@ -555,9 +555,6 @@ def dequantize_(lib):
c_void_p
,
c_void_p
,
c_void_p
,
c_size_t
,
c_size_t
,
c_size_t
,
c_void_p
,
]
lib
.
infiniopDestroyDequantizeDescriptor
.
restype
=
c_int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment