OpenDAS / AutoAWQ · Commits

Commit 2a51f5e9
[Minor] fixed CUDA kernel launching bug
Authored Jun 07, 2023 by ys-2020
Parent: 3a6dfc39
Showing 1 changed file with 20 additions and 20 deletions.

awq/kernels/gemm_cuda_gen.cu (+20, -20)
@@ -25,6 +25,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
   half A_shared_warp[8];
   half B_shared_warp[32];
   for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
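This hunk is the core of the fix: the kernel no longer reads blockIdx.y and blockIdx.z directly but derives the two logical block coordinates from a one-dimensional blockIdx.x (the host-side launch is flattened to match in the last hunk below). A minimal standalone sketch of the same decode pattern, with hypothetical names and sizes rather than the AutoAWQ kernel itself:

#include <cstdio>
#include <cuda_runtime.h>

// Decode a flattened 1-D grid back into two logical block coordinates,
// mirroring the blockIdx_y / blockIdx_z computation added in this commit.
__global__ void flattened_grid_demo(int tiles_y /* e.g. (M+15)/16 * j_factors1 */)
{
    int blockIdx_y = blockIdx.x % tiles_y;  // tile index over the output
    int blockIdx_z = blockIdx.x / tiles_y;  // split-k slice index
    if (threadIdx.x == 0)
        printf("block %d -> (y=%d, z=%d)\n", blockIdx.x, blockIdx_y, blockIdx_z);
}

int main()
{
    int tiles_y = 4, split_k_iters = 2;
    // One x-dimension launch replaces dim3(1, tiles_y, split_k_iters).
    flattened_grid_demo<<<tiles_y * split_k_iters, 32>>>(tiles_y);
    cudaDeviceSynchronize();
    return 0;
}

Running it prints one (y, z) pair per block, matching how blockIdx_y indexes the output tile and blockIdx_z the split-k slice.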
@@ -36,20 +40,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   static constexpr int row_stride_warp = 32 * 8 / 32;
   static constexpr int row_stride = 2 * 32 * 8 / 128;
   bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
-  // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-  bool ld_A_flag = (blockIdx.y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;   // threadIdx.y is warp_id
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;   // threadIdx.y is warp_id
   // bool wb_C_flag = (threadIdx.x / 4) < M;
   half* A_ptr = A
-             + (((int)blockIdx.y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+             + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
              + (((int)threadIdx.x) % (32 / 8)) * 8;
   int* B_ptr = B
             + ((int)threadIdx.y) * (OC / 8) * 2
             + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
-            + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
             + (((int)threadIdx.x) % (128 / 8)) * 1;
-  // Why * 1 in the above line?
   half* A_shared_ptr = A_shared
                      + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
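As a reading aid for the A_ptr arithmetic above: with row_stride_warp = 32 * 8 / 32 = 8, each of the 64 threads (two warps of 32) loads one 8-half chunk of the 16 x 32 A tile. A short host-side sketch, illustrative only and not part of the kernel, that enumerates the mapping:

#include <cstdio>

// Enumerates which row/column chunk of the 16x32 A tile each thread loads,
// following the A_ptr index arithmetic in the hunk above.
int main()
{
    for (int ty = 0; ty < 2; ++ty)            // threadIdx.y: two warps per block
        for (int tx = 0; tx < 32; ++tx) {     // threadIdx.x: lanes within a warp
            int row = ty * 8 + tx / (32 / 8); // row inside the 16-row A tile
            int col = (tx % (32 / 8)) * 8;    // first of the 8 halves along K
            printf("warp %d lane %2d -> row %2d, cols %2d..%2d\n",
                   ty, tx, row, col, col + 7);
        }
    return 0;
}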
@@ -62,26 +65,26 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
                      + (((int)threadIdx.x) % (128 / 8)) * 8;
   int* zeros_ptr = zeros
-                + (((int)blockIdx.y) % j_factors1) * (128 / 8)
+                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
                 + ((int)threadIdx.x) % (128 / 8);
   half* scaling_factors_ptr = scaling_factors
-                           + (((int)blockIdx.y) % j_factors1) * (128)
+                           + (((int)blockIdx_y) % j_factors1) * (128)
                            + (((int)threadIdx.x) % (128 / 8)) * 8;
   half* C_ptr = C
-            + blockIdx.z * M * OC        // blockIdz.x -> split_k dim
+            + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
-            + (((int)blockIdx.y) % j_factors1) * 128
+            + (((int)blockIdx_y) % j_factors1) * 128
             + ((int)threadIdx.y) * 64
             + (((int)threadIdx.x) % 4) * 2;
   // preload s.f. and zeros
   int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * 32 + blockIdx.z >= IC) k_bound -= 1;
+  if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
   for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-    int k_0_0 = _k_0_0 * split_k_iters + blockIdx.z;
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
     __syncthreads();
-    // TODO: Haotian: blockIdx.y / j_factors1 in A loading to support bsz > 16
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
     if (ld_A_flag)
     {
       *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
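The k_bound / k_0_0 lines implement the split-k schedule: slice blockIdx_z of the grid handles every split_k_iters-th 32-wide chunk of the IC dimension, and each slice writes its partial result into its own M x OC slab of C (see the blockIdx_z * M * OC term in C_ptr above). A small host-side illustration of that mapping, with arbitrary example sizes:

#include <cstdio>

// Illustrates the split-k schedule used in the kernel: slice z processes
// IC/32 chunks k_0_0 = z, z + split_k_iters, z + 2*split_k_iters, ...
int main()
{
    const int IC = 512, split_k_iters = 4;                // arbitrary example sizes
    for (int z = 0; z < split_k_iters; ++z) {
        int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
        if ((k_bound - 1) * 32 + z >= IC) k_bound -= 1;   // same boundary guard as the kernel
        printf("blockIdx_z = %d -> k_0_0:", z);
        for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0)
            printf(" %d", _k_0_0 * split_k_iters + z);
        printf("\n");
    }
    return 0;
}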
@@ -96,7 +99,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
       uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
       uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / 128 * (OC));
       /*
-      if (blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
         printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
       }
       */
@@ -104,12 +107,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
     int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
     for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
-      // TODO: Shang: double check how to get 8.
       // B: 32 x 136 (128+8) float16
       // each warp: 32 x 4
       // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx.y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
       // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
       uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
       uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
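Each iteration loads one 32-bit word of B holding eight 4-bit weights and expands it to eight fp16 values with dequantize_s4_to_fp16x2 before the zero points and scales are applied (next hunk). Purely for orientation, and not the repo's dequantize_s4_to_fp16x2 (which operates on packed halves with bit tricks and its own nibble ordering), here is a naive float version of what the per-word dequantization computes:

#include <cstdio>
#include <cstdint>

// Naive reference: unpack eight 4-bit weights from one 32-bit word and apply
// w = (q - zero) * scale in float. The scale/zero values are made up.
int main()
{
    uint32_t packed = 0x76543210u;       // eight 4-bit values 0..7, one per nibble
    float scale = 0.05f, zero = 4.0f;
    for (int i = 0; i < 8; ++i) {
        int q = (packed >> (4 * i)) & 0xF;
        printf("w[%d] = (%d - %.0f) * %.2f = %+.3f\n",
               i, q, zero, scale, (q - zero) * scale);
    }
    return 0;
}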
@@ -127,7 +129,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
       asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
       asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
       /*
-      if (ax0_ax1_fused_0 == 0 && blockIdx.z == 0 && blockIdx.y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
         printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
       }
       */
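The sub.f16x2 / fma.rn.f16x2 pair applies the affine dequantization to two packed halves at a time: w = (q - zero) * scale, with a packed-zero addend in the FMA. An intrinsic-level equivalent for a single half2, offered only as a sketch (assumes cuda_fp16.h and a GPU with fp16 arithmetic support):

#include <cstdio>
#include <cuda_fp16.h>

// Intrinsic-level equivalent (per half2 component) of the sub.f16x2 +
// fma.rn.f16x2 pair above: w = (q - zero) * scale on two halves at once.
__global__ void dequant_half2_demo(__half2 q, __half2 zero, __half2 scale)
{
    __half2 w = __hfma2(__hsub2(q, zero), scale, __float2half2_rn(0.0f));
    printf("w = (%f, %f)\n", __low2float(w), __high2float(w));
}

int main()
{
    dequant_half2_demo<<<1, 1>>>(__floats2half2_rn(5.0f, 9.0f),   // quantized values
                                 __floats2half2_rn(4.0f, 4.0f),   // zero points
                                 __floats2half2_rn(0.5f, 0.5f));  // scales
    cudaDeviceSynchronize();
    return 0;
}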
@@ -194,7 +196,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int spli
   // TODO: Shang: Hoist loop invariance.
   for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
     for (int local_id = 0; local_id < 8; ++local_id) {
-      int row_offset = (((int)blockIdx.y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
       if (row_offset < M)
       {
         *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
@@ -231,15 +233,13 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-    // blockIdx.x: i_factors[0] * j_factors[0]
-    // blockIdx.y: i_factors[1] * j_factors[1]
     if (num_out_channels % 128 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 128");
     if (num_out_channels % 8 != 0)
         throw std::invalid_argument("OC is not multiple of pack_num = 8");
     int j_factors1 = num_out_channels / 128 / 1;
-    dim3 num_blocks(1, (num_out_feats + 16 - 1) / 16 * j_factors1, split_k_iters);
+    dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
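This is the host-side half of the fix. The old launch put the tile count in gridDim.y, which CUDA caps at 65535, while gridDim.x allows up to 2^31 - 1 blocks; flattening the whole grid into the x dimension (and decoding it in-kernel, as in the first hunk) presumably avoids launch failures once (num_out_feats + 15) / 16 * j_factors1 exceeds that cap. A small standalone comparison with made-up sizes:

#include <cstdio>

// Compare the old 3-D launch shape with the flattened 1-D one for an example
// problem size. CUDA limits gridDim.y (and .z) to 65535 blocks, while
// gridDim.x allows up to 2^31 - 1.
int main()
{
    long long num_out_feats = 16384;      // example number of output rows (tokens)
    long long j_factors1 = 13824 / 128;   // example num_out_channels / cta_N
    long long split_k_iters = 8;

    long long tiles = (num_out_feats + 16 - 1) / 16 * j_factors1;
    printf("old: dim3(1, %lld, %lld) -> gridDim.y %s the 65535 limit\n",
           tiles, split_k_iters, tiles > 65535 ? "exceeds" : "fits under");
    printf("new: dim3(%lld)          -> fits in gridDim.x\n",
           tiles * split_k_iters);
    return 0;
}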