norm / vllm / Commits
"torchvision/vscode:/vscode.git/clone" did not exist on "86a14cbad46f6f026ffcee7f504ffaca8da33929"
Commit beb89f68 (Unverified)
AWQ: Up to 2.66x higher throughput (#2566)
Authored Jan 27, 2024 by Casper; committed by GitHub on Jan 26, 2024
Parent: 390b495f
Showing 4 changed files with 127 additions and 1 deletion (+127, -1):
  csrc/ops.h                                      +8    -0
  csrc/pybind.cpp                                 +1    -0
  csrc/quantization/awq/gemm_kernels.cu           +108  -0
  vllm/model_executor/layers/quantization/awq.py  +10   -1
csrc/ops.h
@@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy);
#endif

void squeezellm_gemm(
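The new awq_dequantize entry point takes the same packed tensors as awq_gemm plus explicit launch hints (split_k_iters, thx, thy). Each int32 in _kernel packs eight 4-bit weights, so dequantization expands the packed output dimension by a factor of 8; the host wrapper in gemm_kernels.cu below allocates the result as (in_c, qout_c * 8). A minimal shape sketch in Python, with the dimensions here being illustrative only:

import torch

in_c, qout_c, group_size = 4096, 512, 128   # hypothetical AWQ layer dimensions

qweight = torch.zeros(in_c, qout_c, dtype=torch.int32)                      # packed 4-bit weights
scales  = torch.zeros(in_c // group_size, qout_c * 8, dtype=torch.float16)  # per-group scales
qzeros  = torch.zeros(in_c // group_size, qout_c, dtype=torch.int32)        # packed zero points

# Eight 4-bit values per int32, so the dequantized weight has qout_c * 8 columns.
dequantized_shape = (in_c, qout_c * 8)
print(dequantized_shape)  # (4096, 4096)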
csrc/pybind.cpp
@@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
  // Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
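With the binding registered next to awq_gemm, the op can be called directly from Python. A rough usage sketch, assuming the compiled extension is importable as vllm._C; the zero values for split_k_iters, thx and thy mirror the call site in awq.py further down and let the host wrapper pick its own launch configuration:

import torch
from vllm._C import ops  # assumed import path for the compiled extension

qweight = torch.randint(0, 2**31 - 1, (4096, 512), dtype=torch.int32, device="cuda")
scales = torch.rand(32, 4096, device="cuda").half()   # group size 4096 / 32 = 128
qzeros = torch.randint(0, 2**31 - 1, (32, 512), dtype=torch.int32, device="cuda")

# split_k_iters=0, thx=0, thy=0: the wrapper derives the grid from the tensor shapes.
fp16_weight = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
print(fp16_weight.shape, fp16_weight.dtype)  # torch.Size([4096, 4096]) torch.float16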
csrc/quantization/awq/gemm_kernels.cu
@@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
#endif
}

__global__ void __launch_bounds__(64) dequantize_weights(
  int* __restrict__ B,
  half* __restrict__ scaling_factors,
  int* __restrict__ zeros,
  half* __restrict__ C,
  int G
)
{
  int j_factors1 = 4;
  int row_stride2 = 4;
  int split_k_iters = 1;
  static constexpr uint32_t ZERO = 0x0;
  half B_shared[32 * (128 + 8)];

  half* B_shared_ptr2 = B_shared;

  half B_shared_warp[32];
  int OC = 512;

  int N = blockDim.x * gridDim.x;  // 2
  int col = (blockIdx.x * blockDim.x + threadIdx.x);
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int index1 = 8 * col + 8 * row * N;
  half* C_ptr2 = C + index1;

  int index2 = col + row * N;
  int* B_ptr2 = B + index2;

  int index3 = col + (int)(row / G) * N;
  int* zeros_ptr2 = zeros + index3;
  int index4 = 8 * col + (int)(row / G) * N * 8;
  half* scaling_factors_ptr2 = scaling_factors + index4;

  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);

  int j = 0;
  uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));

  *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;

  for (int i = 0; i < 8; ++i) {
    *(C_ptr2 + i) = B_shared[i];
  }
}

} // namespace awq
} // namespace vllm

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy)
{
  int in_c = _kernel.size(0);
  int qout_c = _kernel.size(1);
  int out_c = qout_c * 8;
  int G = in_c / _scaling_factors.size(0);

  int x_thread = thx;
  int y_thread = thy;

  int x_blocks = 1;
  int y_blocks = 1;
  if (thx == 0) {
    x_thread = qout_c;
  }
  if (thy == 0) {
    y_thread = in_c;
  }
  if (thx == 0 && thy == 0) {
    x_thread = 8;
    y_thread = 8;
    x_blocks = (int)(qout_c / 8);
    y_blocks = (int)(in_c / 8);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));

  auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
  at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);

  auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
  auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
  auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
  auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_thread, y_thread);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
    kernel, scaling_factors, zeros, de_kernel, G);

  return _de_kernel;
}

// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
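In scalar terms, each sub.f16x2 / fma.rn.f16x2 pair above computes the usual zero-point dequantization w = (q - z) * s on two half values at a time, with one zero point and one scale shared by every group of G input rows. A rough PyTorch reference of the same arithmetic, assuming the 4-bit values have already been unpacked to integers and ignoring AWQ's interleaved nibble ordering; dequantize_reference is a name made up for this sketch:

import torch

def dequantize_reference(q, zeros, scales, group_size):
    # q:      (IC, OC) integer tensor of unpacked 4-bit weights
    # zeros:  (IC // group_size, OC) integer zero points
    # scales: (IC // group_size, OC) float16 scales
    z = zeros.repeat_interleave(group_size, dim=0).to(torch.float16)
    s = scales.repeat_interleave(group_size, dim=0)
    return (q.to(torch.float16) - z) * s

q = torch.randint(0, 16, (256, 128))
z = torch.randint(0, 16, (2, 128))
s = torch.rand(2, 128).half()
w = dequantize_reference(q, z, s, group_size=128)  # (256, 128) float16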
vllm/model_executor/layers/quantization/awq.py
@@ -153,7 +153,16 @@ class AWQLinearMethod(LinearMethodBase):
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # num_tokens >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)

        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
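The new branch chooses between the two paths on the flattened token count: with at least 256 tokens the weight is dequantized once and multiplied with a plain fp16 matmul, otherwise the fused awq_gemm kernel is kept. A minimal sketch of that decision rule, with the helper name invented for illustration:

import torch

FP16_MATMUL_TOKEN_THRESHOLD = 256  # same threshold as FP16_MATMUL_HEURISTIC_CONDITION above

def use_fp16_matmul(x: torch.Tensor) -> bool:
    # x has shape (..., hidden_size); every dimension except the last counts as a token.
    return x.shape[:-1].numel() >= FP16_MATMUL_TOKEN_THRESHOLD

print(use_fp16_matmul(torch.empty(8, 16, 4096)))   # 128 tokens -> False: fused awq_gemm
print(use_fp16_matmul(torch.empty(32, 16, 4096)))  # 512 tokens -> True: awq_dequantize + matmul

Dequantizing the full weight matrix adds its own cost, so it presumably only pays off once enough tokens share the dequantized weight for the fp16 GEMM's throughput advantage to dominate, which is consistent with the "up to 2.66x" figure in the commit title.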