OpenDAS / AutoAWQ_kernels

Commit 16c5fe16, authored Dec 28, 2023 by Casper
Add dequantization kernel
Parent: b5592bd6
Showing 3 changed files with 141 additions and 1 deletion.
awq_ext/pybind_awq.cpp                  +1    -0
awq_ext/quantization/gemm_cuda.h        +5    -1
awq_ext/quantization/gemm_cuda_gen.cu   +135  -0
awq_ext/pybind_awq.cpp
@@ -12,4 +12,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
     m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
+    m.def("dequantize_weights_cuda", &dequantize_weights_cuda, "Dequantize weights.");
 }
\ No newline at end of file
awq_ext/quantization/gemm_cuda.h
@@ -4,4 +4,8 @@ torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
                             torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters);
 
 torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel,
-                                  torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
\ No newline at end of file
+                                  torch::Tensor _scaling_factors, torch::Tensor _zeros, int group_size, int split_k_iters);
+
+// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda.h#L9C1-L10C106
+torch::Tensor dequantize_weights_cuda(torch::Tensor _kernel, torch::Tensor _scaling_factors,
+                                      torch::Tensor _zeros, int split_k_iters, int thx, int thy, bool dbg);
\ No newline at end of file
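A minimal call-site sketch (not part of the commit): how the new entry point might be invoked from C++ with libtorch, assuming the awq_ext sources are compiled into the same binary. The shapes follow the 4096x64 example noted in the kernel comments below; the helper name dequantize_example is hypothetical.

#include <torch/torch.h>
#include "gemm_cuda.h"

torch::Tensor dequantize_example() {
    auto int_opts  = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    auto half_opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);

    torch::Tensor qweight = torch::zeros({4096, 64}, int_opts);  // IC x OC/8, 8 int4 weights per int32
    torch::Tensor scales  = torch::ones({32, 512}, half_opts);   // (IC/G) x OC, so G = 128
    torch::Tensor qzeros  = torch::zeros({32, 64}, int_opts);    // (IC/G) x OC/8, packed int4

    // thx = thy = 0 lets the launcher pick an 8x8 block over a (qout_c/8, in_c/8)
    // grid (one thread per packed int32 word) and forces the debug prints off;
    // split_k_iters is accepted but not used on this path.
    return dequantize_weights_cuda(qweight, scales, qzeros,
                                   /*split_k_iters=*/1, /*thx=*/0, /*thy=*/0, /*dbg=*/false);
}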
awq_ext/quantization/gemm_cuda_gen.cu
@@ -15,6 +15,7 @@
 #include "dequantize.cuh"
 #include <cuda_fp16.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <cublas_v2.h>
 
 // Pack two half values.
@@ -724,6 +725,140 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s
   }
 }
 
+// Dequantization to fp16 kernel
+// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L32C1-L32C1
+__global__ void __launch_bounds__(64) dequantize_weights(
+    int* __restrict__ B,                  // 4096x64    4096 rows    64 cols
+    half* __restrict__ scaling_factors,   // 32x512       32 rows   512 cols
+    int* __restrict__ zeros,              // 32x64        32 rows    64 cols
+    half* __restrict__ C,                 // 4096x512   4096 rows   512 cols
+    int G,
+    bool dbg)
+{
+  int j_factors1 = 4;
+  int row_stride2 = 4;
+  int split_k_iters = 1;
+  static constexpr uint32_t ZERO = 0x0;
+  half B_shared[32 * (128 + 8)];
+
+  half* B_shared_ptr2 = B_shared;
+
+  half B_shared_warp[32];
+  int OC = 512;
+
+  int N = blockDim.x * gridDim.x;  // 2
+  int col = (blockIdx.x * blockDim.x + threadIdx.x);
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+  int index1 = 8 * col + 8 * row * N;  // + i (<8)
+  half* C_ptr2 = C + index1;
+
+  int index2 = col + row * N;
+  int* B_ptr2 = B + index2;
+
+  if (dbg) {
+    printf("\n-------- x %d - y %d --------\n", col, row);
+    printf("- %d-%d - N %d index1 %d\n", col, row, N, index2);
+    printf("- %d-%d - B %d\n", col, row, *B_ptr2);
+  }
+
+  int index3 = col + (int)(row / G) * N;
+  int* zeros_ptr2 = zeros + index3;
+  int index4 = 8 * col + (int)(row / G) * N * 8;  // + i (<8)
+  half* scaling_factors_ptr2 = scaling_factors + index4;
+
+  if (dbg) {
+    printf("- %d-%d - zeros[%d] %d\n", col, row, index3, *zeros_ptr2);
+    printf("- %d-%d - N %d index4 %d\n", col, row, N, index4);
+    for (int i = 0; i < 8; ++i) {
+      printf("- %d-%d - scale[%d] %f\n", col, row, index4 + i, __half2float(*(scaling_factors_ptr2 + i)));
+    }
+  }
+
+  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+
+  int j = 0;
+
+  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+  if (dbg) {
+    for (int i = 0; i < 8; ++i) {
+      printf("- %d-%d - B_shared_ptr2[%d] %f\n", col, row, i, __half2float(*(B_shared_ptr2 + i)));
+    }
+  }
+
+  for (int i = 0; i < 8; ++i) {
+    *(C_ptr2 + i) = B_shared[i];
+  }
+}
+
+// Dequantization to fp16
+// Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L935C1-L987C2
+torch::Tensor dequantize_weights_cuda(
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters,
+    int thx,
+    int thy,
+    bool dbg)
+{
+  int in_c = _kernel.size(0);
+  int qout_c = _kernel.size(1);
+  int out_c = qout_c * 8;
+  int G = in_c / _scaling_factors.size(0);
+
+  int x_thread = thx;
+  int y_thread = thy;
+
+  int x_blocks = 1;
+  int y_blocks = 1;
+  if (thx == 0) {
+    x_thread = qout_c;
+  }
+  if (thy == 0) {
+    y_thread = in_c;
+  }
+  int dbg_ = true;
+  if (thx == 0 && thy == 0) {
+    dbg_ = false;
+    x_thread = 8;
+    y_thread = 8;
+    x_blocks = (int)(qout_c / 8);
+    y_blocks = (int)(in_c / 8);
+  }
+  dbg = dbg && dbg_;
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);  // row, col 4096x512
+
+  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+  dim3 num_blocks(x_blocks, y_blocks);
+  dim3 threads_per_block(x_thread, y_thread);  // col, row 64x4096
+
+  dequantize_weights<<<num_blocks, threads_per_block>>>(kernel, scaling_factors, zeros, de_kernel, G, dbg);
+
+  return _de_kernel;
+}
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
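The eight sub.f16x2 / fma.rn.f16x2 pairs in the new kernel implement the usual AWQ dequantization, w = (q - z) * s, two half values at a time. A scalar sketch of that arithmetic (illustrative only, not part of the commit) is below; it ignores the packed-lane ordering produced by dequantize_s4_to_fp16x2, and the helper name dequant_one is hypothetical.

#include <cuda_fp16.h>
#include <cstdint>

// Scalar equivalent of one lane of the packed sub + fma sequence: w = (q - z) * s.
// The real kernel works on fp16x2 pairs from dequantize_s4_to_fp16x2; the plain
// bit extraction here does not reproduce that function's lane ordering.
__device__ __forceinline__ half dequant_one(uint32_t packed_q, uint32_t packed_z,
                                            int lane, half scale)
{
    float q = static_cast<float>((packed_q >> (4 * lane)) & 0xF);  // 4-bit quantized weight
    float z = static_cast<float>((packed_z >> (4 * lane)) & 0xF);  // matching 4-bit zero point
    return __float2half((q - z) * __half2float(scale));
}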
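For the 4096x64 shapes used in the kernel comments, the thx = thy = 0 branch of dequantize_weights_cuda resolves to the geometry below. This is a standalone host-side check under those assumed shapes, not code from the commit.

#include <cstdio>

int main() {
    const int in_c    = 4096;        // _kernel.size(0)
    const int qout_c  = 64;          // _kernel.size(1), 8 int4 weights per int32
    const int out_c   = qout_c * 8;  // 512 dequantized columns
    const int sf_rows = 32;          // _scaling_factors.size(0)
    const int G = in_c / sf_rows;    // group size = 128

    // thx == 0 && thy == 0 branch of the launcher:
    const int x_thread = 8, y_thread = 8;
    const int x_blocks = qout_c / 8;  // 8
    const int y_blocks = in_c / 8;    // 512

    std::printf("grid (%d, %d), block (%d, %d) -> %d threads, one per packed int32 word\n",
                x_blocks, y_blocks, x_thread, y_thread,
                x_blocks * y_blocks * x_thread * y_thread);
    std::printf("output: %d x %d halfs, group size G = %d\n", in_c, out_c, G);
    return 0;
}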