OpenDAS / AutoAWQ · Commits · Commit 06e299ba (Unverified)

Authored Jun 06, 2023 by Haotian (Ken) Tang, committed by GitHub on Jun 06, 2023

Merge pull request #8 from ys-2020/main

[Major] Fix W4A16 CUDA kernel launching bug.

Parents: 3a6dfc39, 2a51f5e9

Changes: 1 changed file with 20 additions and 20 deletions

awq/kernels/gemm_cuda_gen.cu (+20, -20)
@@ -25,6 +25,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
+  int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
   half A_shared_warp[8];
   half B_shared_warp[32];
   for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
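Note (not part of the commit): the four added lines above recover the logical grid coordinates from a flattened 1-D launch; the (blockIdx_y, blockIdx_z) pair matches what the old three-dimensional launch supplied directly as blockIdx.y and blockIdx.z. A minimal standalone sketch of that mapping, using illustrative values for M, OC, and split_k_iters:

// Standalone sketch (not from the commit): check how a flattened 1-D block index
// decomposes into the same (blockIdx_y, blockIdx_z) pair the kernel computes.
#include <cstdio>

int main() {
    const int M = 24, OC = 512, split_k_iters = 8;   // illustrative sizes only
    const int j_factors1 = (OC + 128 - 1) / 128;     // output-channel tiles of width 128
    const int m_tiles    = (M + 16 - 1) / 16;        // row tiles of height 16
    const int num_blocks = m_tiles * j_factors1 * split_k_iters;

    for (int bx = 0; bx < num_blocks; ++bx) {
        int blockIdx_y = bx % (m_tiles * j_factors1); // which output tile
        int blockIdx_z = bx / (m_tiles * j_factors1); // which split-k slice
        printf("block %3d -> tile %2d, split_k slice %d\n", bx, blockIdx_y, blockIdx_z);
    }
    return 0;
}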
@@ -36,20 +40,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   static constexpr int row_stride_warp = 32 * 8 / 32;
   static constexpr int row_stride = 2 * 32 * 8 / 128;
   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
-  // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
-  bool ld_A_flag = (blockIdx.y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
   // bool wb_C_flag = (threadIdx.x / 4) < M;
   half* A_ptr = A
-                + (((int)blockIdx.y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
                 + (((int)threadIdx.x) % (32 / 8)) * 8;
   int* B_ptr = B
             + ((int)threadIdx.y) * (OC / 8) * 2
             + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
-            + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
             + (((int)threadIdx.x) % (128 / 8)) * 1;
                         // Why * 1 in the above line?
   half* A_shared_ptr = A_shared
                     + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
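Note (not part of the commit): ld_A_flag guards the A load against rows past M, which matters when M is not a multiple of the 16-row tile. A small sketch with illustrative values showing which rows of A a block's threads would touch and which loads get masked:

// Sketch (illustrative values, not from the commit): per-thread row of A and the < M guard.
#include <cstdio>

int main() {
    const int M = 20, j_factors1 = 4;              // illustrative: 20 rows, 4 column tiles
    const int row_stride_warp = 32 * 8 / 32;       // 8 rows per warp, as in the kernel
    int blockIdx_y = 1 * j_factors1;               // a block in the second 16-row tile

    for (int ty = 0; ty < 2; ++ty)                 // threadIdx.y: 2 warps per block
        for (int tx = 0; tx < 32; tx += 4) {       // threadIdx.x (every 4th lane shown)
            int row = blockIdx_y / j_factors1 * 16 + ty * row_stride_warp + tx * 8 / 32;
            bool ld_A_flag = row < M;              // rows 20..31 of this tile are skipped
            printf("warp %d lane %2d -> row %2d %s\n", ty, tx, row, ld_A_flag ? "" : "(masked)");
        }
    return 0;
}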
@@ -62,26 +65,26 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
                     + (((int)threadIdx.x) % (128 / 8)) * 8;
   int* zeros_ptr = zeros
-                + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
                 + ((int)threadIdx.x) % (128 / 8);
   half* scaling_factors_ptr = scaling_factors
-                            + (((int)blockIdx.y) % j_factors1) * (128)
+                            + (((int)blockIdx_y) % j_factors1) * (128)
                             + (((int)threadIdx.x) % (128 / 8)) * 8;
   half* C_ptr = C
-              + blockIdx.z * M * OC        // blockIdz.x -> split_k dim
-              + (((int)blockIdx.y) % j_factors1) * 128
+              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 128
               + ((int)threadIdx.y) * 64
               + (((int)threadIdx.x) % 4) * 2;
   // preload s.f. and zeros
   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * 32 + blockIdx.z >= IC) k_bound -= 1;
+  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-    int k_0_0 = _k_0_0 * split_k_iters + blockIdx.z;
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
     __syncthreads();
-    // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
     if (ld_A_flag)
     {
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
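Note (not part of the commit): the main loop above walks the reduction (K) dimension in 32-wide chunks, and k_bound is the ceiling of (IC / 32) over split_k_iters, so each split-k slice handles an interleaved subset of chunks. A worked sketch with illustrative IC and split_k_iters:

// Sketch (illustrative values, not from the commit): which 32-wide K chunks each
// split-k slice (blockIdx_z) processes in the main loop.
#include <cstdio>

int main() {
    const int IC = 4096, split_k_iters = 8;
    const int num_chunks = IC / 32;                                  // 128 chunks of K
    int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;     // ceil(128 / 8) = 16

    for (int blockIdx_z = 0; blockIdx_z < split_k_iters; ++blockIdx_z) {
        printf("slice %d: chunks", blockIdx_z);
        for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
            int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;         // interleaved assignment
            if (k_0_0 < num_chunks) printf(" %d", k_0_0);            // guard mirrors the kernel's k_bound trimming
        }
        printf("\n");
    }
    return 0;
}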
@@ -96,7 +99,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
      uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
      uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
      /*
-      if (blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
        printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
      }
      */
@@ -104,12 +107,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
      int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
      for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
        // TODO: Shang: double check how to get 8.
        // B: 32 x 136 (128+8) float16
        // each warp: 32 x 4
        // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-        // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx.y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+        // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
        // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
        uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
        uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
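Note (not part of the commit): as the comments above say, each thread reads 32 bits of B, i.e. eight 4-bit weights packed into one uint32_t, which dequantize_s4_to_fp16x2 then expands to eight fp16 values. A generic nibble-unpacking sketch; the real in-word ordering is whatever dequantize_s4_to_fp16x2 defines and may differ from this plain low-to-high order:

// Sketch (not the kernel's code): eight 4-bit weights stored in one 32-bit word.
#include <cstdio>
#include <cstdint>

int main() {
    uint32_t packed = 0x76543210u;                // illustrative packed int4 weights
    for (int i = 0; i < 8; ++i) {
        uint32_t q = (packed >> (4 * i)) & 0xF;   // plain low-to-high nibble order
        printf("weight %d = %u\n", i, q);
    }
    return 0;
}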
@@ -127,7 +129,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
        /*
-        if (ax0_ax1_fused_0 == 0 && blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
          printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
        }
        */
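Note (not part of the commit): the sub.f16x2 / fma.rn.f16x2 pair above computes, for each packed half in the shown lane, roughly (quantized - zero) * scale with a constant-zero addend. A scalar sketch of that arithmetic with illustrative values:

// Scalar sketch (not from the file) of what the packed PTX above computes per value:
// dequantized = (q - zero_point) * scale + 0.0
#include <cstdio>

int main() {
    float q = 11.0f;       // a 4-bit weight already expanded to fp16 (0..15)
    float zero = 8.0f;     // per-group zero point, also expanded to fp16
    float scale = 0.042f;  // per-group scaling factor
    float ZERO = 0.0f;     // the kernel's fma addend is a constant zero register

    float w = (q - zero) * scale + ZERO;   // sub.f16x2 then fma.rn.f16x2, one lane
    printf("dequantized weight = %f\n", w);
    return 0;
}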
@@ -194,7 +196,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
  // TODO: Shang: Hoist loop invariance.
  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
    for (int local_id = 0; local_id < 8; ++local_id) {
-      int row_offset = (((int)blockIdx.y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
      if (row_offset < M)
      {
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
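Note (not part of the commit): combining the row_offset expression above with the column terms already folded into C_ptr gives the output element each accumulator value lands on; the < M check again masks rows of a partial 16-row tile. A sketch for one thread, with illustrative indices and the split-k slice offset ignored:

// Sketch (not from the commit): where one thread's 8 accumulator values land in C,
// combining the row_offset and C_ptr expressions from the kernel above.
#include <cstdio>

int main() {
    const int M = 16, OC = 512, j_factors1 = OC / 128;
    int blockIdx_y = 0, threadIdx_x = 5, threadIdx_y = 0, ax1_0_1 = 0;

    // column base that C_ptr already contains (column tile, warp, lane pair)
    int col_base = (blockIdx_y % j_factors1) * 128 + threadIdx_y * 64 + (threadIdx_x % 4) * 2;

    for (int local_id = 0; local_id < 8; ++local_id) {
        int row = (blockIdx_y / j_factors1) * 16 + threadIdx_x / 4 + (local_id % 4) / 2 * 8;
        int col = col_base + ax1_0_1 * 16 + (local_id / 4) * 8 + local_id % 2;
        if (row < M)
            printf("local_id %d -> C[%d][%d]\n", local_id, row, col);
    }
    return 0;
}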
@@ -231,15 +233,13 @@ torch::Tensor gemm_forward_cuda(
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-    // blockIdx.x: i_factors[0] * j_factors[0]
-    // blockIdx.y: i_factors[1] * j_factors[1]
    if (num_out_channels % 128 != 0)
        throw std::invalid_argument("OC is not multiple of cta_N = 128");
    if (num_out_channels % 8 != 0)
        throw std::invalid_argument("OC is not multiple of pack_num = 8");
    int j_factors1 = num_out_channels / 128 / 1;
-    dim3 num_blocks(1, (num_out_feats + 16 - 1) / 16 * j_factors1, split_k_iters);
+    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
    // threadIdx.x: 32
    // threadIdx.y: i_factors[2] * j_factors[2]
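Note (not part of the commit): this host-side change is the launching half of the fix. The old launch put (num_out_feats + 16 - 1) / 16 * j_factors1 into gridDim.y, and CUDA caps gridDim.y and gridDim.z at 65,535, so a sufficiently large num_out_feats would presumably make the launch fail; the new launch folds everything into gridDim.x (which allows up to 2^31 - 1 blocks) and lets the kernel recover blockIdx_y and blockIdx_z itself, as added at the top of the diff. A sketch comparing the two configurations with illustrative sizes:

// Sketch (illustrative sizes, not from the file): old vs. new grid for the same problem.
#include <cstdio>

int main() {
    const int num_out_feats = 100000;             // e.g. a long prefill batch
    const int num_out_channels = 4096, split_k_iters = 8;
    const int j_factors1 = num_out_channels / 128 / 1;

    long long y_old = (long long)(num_out_feats + 16 - 1) / 16 * j_factors1;
    printf("old: dim3(1, %lld, %d) -> gridDim.y %s the 65535 limit\n",
           y_old, split_k_iters, y_old > 65535 ? "exceeds" : "fits under");

    long long x_new = y_old * split_k_iters;      // everything folded into gridDim.x
    printf("new: dim3(%lld) -> fits in gridDim.x for this example (limit 2^31 - 1)\n", x_new);
    return 0;
}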