Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b00b33d7
Unverified
Commit
b00b33d7
authored
Nov 19, 2024
by
ElizaWszola
Committed by
GitHub
Nov 19, 2024
Browse files
[Model][Quantization] HQQ support through Marlin kernel expansion (#9766)
Signed-off-by:
ElizaWszola
<
eliza@neuralmagic.com
>
parent
efa90846
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
632 additions
and
89 deletions
+632
-89
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+2
-1
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+2
-2
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+200
-77
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+1
-1
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+87
-1
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+2
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-3
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+2
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/hqq_marlin.py
vllm/model_executor/layers/quantization/hqq_marlin.py
+325
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+4
-2
No files found.
benchmarks/kernels/benchmark_machete.py
View file @
b00b33d7
...
...
@@ -210,7 +210,8 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
size_m
=
bt
.
a
.
shape
[
0
],
size_n
=
bt
.
w_ref
.
shape
[
1
],
size_k
=
bt
.
w_ref
.
shape
[
0
],
is_k_full
=
True
)
is_k_full
=
True
,
is_zp_float
=
False
)
else
:
assert
bt
.
a
.
dtype
==
torch
.
int8
assert
bt
.
wtype
==
scalar_types
.
uint4b8
...
...
benchmarks/kernels/benchmark_marlin.py
View file @
b00b33d7
...
...
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)"
,
# noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False,
False,
False)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
@@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)"
,
# noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True
, False
)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
b00b33d7
...
...
@@ -54,9 +54,10 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
has_act_order
,
// whether act_order is enabled
const
int
group_blocks
=
-
1
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeId
const
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
bool
is_k_full
,
bool
has_zp
,
bool
is_zp_float
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
...
...
@@ -516,10 +517,11 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
=
-
1
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
...
...
@@ -692,8 +694,10 @@ __global__ void Marlin(
int
act_s_col_tb_stride
=
act_s_col_warp_stride
*
tb_n_warps
;
// Zero-points sizes/strides
int
zp_gl_stride
=
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
int
zp_gl_stride
=
is_zp_float
?
prob_n
/
8
:
(
prob_n
/
pack_factor
)
/
4
;
constexpr
int
zp_sh_stride
=
is_zp_float
?
16
*
thread_n_blocks
/
8
:
((
16
*
thread_n_blocks
)
/
pack_factor
)
/
4
;
constexpr
int
zp_tb_groups
=
s_tb_groups
;
constexpr
int
zp_sh_stage
=
has_zp
?
zp_tb_groups
*
zp_sh_stride
:
0
;
int
zp_gl_rd_delta
=
zp_gl_stride
;
...
...
@@ -768,9 +772,16 @@ __global__ void Marlin(
constexpr
int
num_ints_per_thread
=
8
/
pack_factor
;
int
zp_sh_rd
;
if
constexpr
(
has_zp
)
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
if
constexpr
(
is_zp_float
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
zp_sh_rd
=
8
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
(
threadIdx
.
x
%
32
)
/
4
;
}
}
else
{
zp_sh_rd
=
num_ints_per_thread
*
num_col_threads
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
num_ints_per_thread
*
((
threadIdx
.
x
%
32
)
/
num_row_threads
);
}
}
// Precompute which thread should not read memory in which iterations; this is
...
...
@@ -832,6 +843,7 @@ __global__ void Marlin(
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
int
frag_qzp
[
2
][
num_ints_per_thread
];
// Zero-points
FragZP
frag_zp
;
// Zero-points in fp16
FragZP
frag_zpf
[
2
];
// Zero-points in fp16 in HQQ
// Zero accumulators.
auto
zero_accums
=
[
&
]()
{
...
...
@@ -1126,7 +1138,7 @@ __global__ void Marlin(
// has_zp implies AWQ, which doesn't have act_order,
static_assert
(
!
has_zp
||
group_blocks
!=
0
);
if
constexpr
(
has_zp
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
==
-
1
)
{
...
...
@@ -1170,11 +1182,44 @@ __global__ void Marlin(
}
}
}
else
if
constexpr
(
has_zp
&&
is_zp_float
)
{
int
pipe
=
full_pipe
%
stages
;
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
>=
thread_k_blocks
)
{
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
((
group_blocks
/
thread_k_blocks
)
*
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
int
cur_k
=
warp_row
*
16
;
cur_k
+=
k_iter_size
*
(
k
%
b_sh_wr_iters
);
int
k_blocks
=
cur_k
/
16
;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int
cur_group_id
=
k_blocks
/
group_blocks
;
#pragma nv_diagnostic pop
int4
*
sh_zp_stage
=
sh_zp
+
zp_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
+
cur_group_id
*
zp_sh_stride
];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto
matmul
=
[
&
](
int
k
)
{
if
constexpr
(
has_zp
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
int
zp_quant_0
,
zp_quant_1
;
...
...
@@ -1219,10 +1264,14 @@ __global__ void Marlin(
frag_b1
=
dequant
<
scalar_t
,
w_type_id
>
(
b_quant_1
);
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zpf
[
k
%
2
][
j
],
0
);
}
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
...
...
@@ -1235,10 +1284,14 @@ __global__ void Marlin(
}
// Apply zero-point to frag_b1
if
constexpr
(
has_zp
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zp
[
j
],
1
);
}
else
if
constexpr
(
has_zp
&&
is_zp_float
&&
group_blocks
!=
-
1
)
{
sub_zp
<
scalar_t
>
(
frag_b1
,
frag_zpf
[
k
%
2
][
j
],
1
);
}
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
...
...
@@ -1510,7 +1563,7 @@ __global__ void Marlin(
fetch_scales_to_shared
(
true
,
g_idx
[
slice_k_start
],
g_idx
[
last_g_idx
]);
}
if
constexpr
(
has_zp
&&
group_blocks
==
-
1
)
{
if
constexpr
(
has_zp
&&
!
is_zp_float
&&
group_blocks
==
-
1
)
{
if
(
i
==
0
)
{
fetch_zp_to_shared
();
}
...
...
@@ -1697,23 +1750,27 @@ __global__ void Marlin(
}
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
} \
}
typedef
struct
{
...
...
@@ -1905,51 +1962,96 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
}
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
false)
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
false) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
false) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
true) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
...
...
@@ -1958,7 +2060,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
,
bool
use_fp32_reduce
)
{
int
sms
,
int
max_par
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
if
(
has_zp
)
{
TORCH_CHECK
(
q_type
==
vllm
::
kU4
||
q_type
==
vllm
::
kU8
,
...
...
@@ -2111,6 +2213,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
AWQ_CALL_IF
(
vllm
::
kU8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
vllm
::
kU8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
vllm
::
kU8
,
4
,
8
,
128
)
HQQ_CALL_IF
(
vllm
::
kU4
,
16
,
4
,
256
)
HQQ_CALL_IF
(
vllm
::
kU4
,
8
,
8
,
256
)
HQQ_CALL_IF
(
vllm
::
kU4
,
8
,
4
,
128
)
HQQ_CALL_IF
(
vllm
::
kU4
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
...
...
@@ -2135,7 +2242,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
vllm
::
ScalarTypeId
const
&
b_q_type_id
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
)
{
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
vllm
::
ScalarType
const
b_q_type
=
vllm
::
ScalarType
::
from_id
(
b_q_type_id
);
if
(
has_zp
)
{
TORCH_CHECK
(
...
...
@@ -2148,6 +2255,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
b_q_type
.
str
());
}
if
(
has_zp
&&
is_zp_float
)
{
TORCH_CHECK
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Computation type must be float16 (half) when using float zero "
"points."
);
}
int
pack_factor
=
32
/
b_q_type
.
size_bits
();
// Verify A
...
...
@@ -2257,12 +2370,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
if
(
has_zp
)
{
int
rank
=
b_zeros
.
sizes
().
size
();
TORCH_CHECK
(
rank
==
2
,
"b_zeros rank = "
,
rank
,
" is not 2"
);
TORCH_CHECK
(
b_zeros
.
size
(
0
)
==
num_groups
,
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
/
pack_factor
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
if
(
is_zp_float
)
{
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not size_n = "
,
size_n
);
TORCH_CHECK
(
num_groups
==
b_zeros
.
size
(
0
),
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
num_groups
!=
-
1
,
"num_groups must be != -1"
);
}
else
{
TORCH_CHECK
(
b_zeros
.
size
(
0
)
==
num_groups
,
"b_zeros dim 0 = "
,
b_zeros
.
size
(
0
),
" is not num_groups = "
,
num_groups
);
TORCH_CHECK
(
b_zeros
.
size
(
1
)
==
size_n
/
pack_factor
,
"b_zeros dim 1 = "
,
b_zeros
.
size
(
1
),
" is not size_n / pack_factor = "
,
size_n
/
pack_factor
);
}
}
// Verify workspace size
...
...
@@ -2282,7 +2405,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
,
is_zp_float
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
...
...
@@ -2291,7 +2414,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
,
is_zp_float
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
...
...
csrc/torch_bindings.cpp
View file @
b00b33d7
...
...
@@ -244,7 +244,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor"
);
"bool has_zp, bool use_fp32_reduce
, bool is_zp_float
) -> Tensor"
);
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
...
...
tests/kernels/test_marlin_gemm.py
View file @
b00b33d7
...
...
@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import
marlin_qqq_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
awq_pack
,
gptq_pack
,
gptq_quantize_weights
,
quantize_weights
,
sort_weights
)
from
vllm.scalar_type
import
scalar_types
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
...
...
@@ -40,6 +41,8 @@ MARLIN_N_CHUNKS = [64, 256]
MARLIN_24_K_CHUNKS
=
[
128
]
MARLIN_24_N_CHUNKS
=
[
512
]
HQQ_SUPPORTED_GROUP_SIZES
=
[
64
]
MNK_FACTORS
=
[
(
1
,
1
,
1
),
(
1
,
4
,
8
),
...
...
@@ -226,7 +229,7 @@ def test_gptq_marlin_gemm(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
),
a_input
.
shape
[
1
],
is_k_full
,
False
,
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
...
...
@@ -244,6 +247,7 @@ def test_gptq_marlin_gemm(
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
...
@@ -441,6 +445,7 @@ def test_awq_marlin_gemm(
is_k_full
=
is_k_full
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
...
@@ -451,6 +456,87 @@ def test_awq_marlin_gemm(
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
HQQ_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_hqq_marlin_gemm
(
k_chunk
,
n_chunk
,
group_size
,
mnk_factors
,
use_fp32_reduce
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
quant_type
=
scalar_types
.
uint4
a_input
=
rand_data
((
size_m
,
size_k
))
dev
=
a_input
.
device
b_weight
=
torch
.
randint
(
0
,
10
,
(
size_n
,
size_k
),
dtype
=
torch
.
uint8
,
device
=
dev
)
scale
=
rand_data
((
size_n
,
size_k
//
group_size
))
zero
=
rand_data
((
size_n
,
size_k
//
group_size
))
gptq_w_q
=
gptq_pack
(
b_weight
.
transpose
(
1
,
0
),
4
,
size_k
,
size_n
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
dev
)
marlin_w_q
=
ops
.
gptq_marlin_repack
(
gptq_w_q
,
sort_indices
,
size_k
,
size_n
,
4
).
to
(
dev
)
marlin_s
=
marlin_permute_scales
(
scale
.
transpose
(
1
,
0
),
size_k
,
size_n
,
group_size
).
to
(
dev
)
marlin_zp
=
marlin_permute_scales
(
zero
.
transpose
(
1
,
0
),
size_k
,
size_n
,
group_size
).
to
(
dev
)
g_idx
=
marlin_make_empty_g_idx
(
dev
)
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
dev
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_w_q
,
marlin_s
,
marlin_zp
,
g_idx
,
g_idx_sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
0
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
True
,
)
b_flat
=
b_weight
.
reshape
(
-
1
,
group_size
)
zp_flat
=
zero
.
reshape
(
-
1
,
1
)
s_flat
=
scale
.
reshape
(
-
1
,
1
)
dequant
=
(
b_flat
-
zp_flat
)
*
s_flat
output_ref
=
torch
.
matmul
(
a_input
,
dequant
.
reshape
(
b_weight
.
shape
).
transpose
(
1
,
0
))
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"qqq"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
...
...
tests/weight_loading/models.txt
View file @
b00b33d7
...
...
@@ -27,4 +27,5 @@ fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
\ No newline at end of file
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
\ No newline at end of file
vllm/_custom_ops.py
View file @
b00b33d7
...
...
@@ -343,7 +343,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k
:
torch
.
SymInt
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
register_fake
(
"_C::ggml_dequantize"
)
...
...
@@ -601,11 +602,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k
:
int
,
is_k_full
:
bool
,
has_zp
:
bool
=
False
,
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
use_fp32_reduce
:
bool
=
False
,
is_zp_float
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
,
use_fp32_reduce
)
has_zp
,
use_fp32_reduce
,
is_zp_float
)
# fp8 marlin
...
...
vllm/model_executor/layers/linear.py
View file @
b00b33d7
...
...
@@ -27,7 +27,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
"ModelOptFp8LinearMethod"
,
"IPEXAWQLinearMethod"
,
"IPEXGPTQLinearMethod"
,
"HQQMarlinMethod"
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
b00b33d7
...
...
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.hqq_marlin
import
HQQMarlinConfig
from
vllm.model_executor.layers.quantization.ipex_quant
import
IPEXConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.modelopt
import
ModelOptFp8Config
...
...
@@ -48,6 +49,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
...
...
vllm/model_executor/layers/quantization/hqq_marlin.py
0 → 100644
View file @
b00b33d7
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
marlin_make_empty_g_idx
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
gptq_pack
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
class
HQQMarlinConfig
(
QuantizationConfig
):
"""Config class for HQQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
skip_modules
:
Optional
[
List
[
str
]]
=
None
,
)
->
None
:
assert
group_size
==
64
,
(
"The only supported HQQ group size is "
"currently 64."
)
assert
weight_bits
==
4
,
(
"The only supported HQQ quantization "
"bitsize is currently 4."
)
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
pack_factor
=
32
//
weight_bits
# packed into int32 in GPTQ format
self
.
quant_type
=
scalar_types
.
uint4
self
.
skip_modules
=
skip_modules
def
__repr__
(
self
)
->
str
:
return
(
f
"HQQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"hqq"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"HQQMarlinConfig"
:
wq_params
=
(
config
[
"quant_config"
][
"weight_quant_params"
])
weight_bits
=
cls
.
get_from_keys
(
wq_params
,
[
"nbits"
])
group_size
=
cls
.
get_from_keys
(
wq_params
,
[
"group_size"
])
skip_modules
=
config
[
"skip_modules"
]
return
cls
(
weight_bits
,
group_size
,
skip_modules
)
def
is_layer_skipped
(
self
,
prefix
:
str
)
->
bool
:
# Split the prefix into its dot-separated components
components
=
prefix
.
split
(
'.'
)
# Check if any of the skip modules exactly matches any component
return
self
.
skip_modules
is
not
None
and
any
(
module_name
in
components
for
module_name
in
self
.
skip_modules
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
self
.
is_layer_skipped
(
prefix
):
return
UnquantizedLinearMethod
()
return
HQQMarlinMethod
(
self
)
return
None
# Empty HQQ parameter, will be ignored during loading
class
HQQEmptyParameter
(
BasevLLMParameter
):
def
load_merged_column_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
pass
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
pass
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
pass
def
error_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
raise
ValueError
(
"No loader provided for HQQ parameter!"
)
# HQQ packing creates issues with sharding - therefore, prior to loading, we
# repack to GPTQ. We also reshape the weights to their proper GPTQ shape.
class
HQQweightParameter
(
PackedvLLMParameter
):
# unpack function from https://github.com/mobiusml/hqq
def
unpack_4bit_u8
(
self
,
W_q
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# uint8/2 > uint8
assert
self
.
weight_bits
==
4
,
"Unsupported quant bitsize (must be 4)"
dtype
=
torch
.
uint8
step
=
W_q
.
shape
[
0
]
tmp
=
torch
.
empty
([
2
*
step
,
W_q
.
shape
[
1
]],
dtype
=
dtype
,
device
=
W_q
.
device
)
tmp
[:
step
]
=
(
W_q
&
0b11110000
)
>>
4
tmp
[
step
:]
=
W_q
&
0b00001111
return
tmp
def
__init__
(
self
,
packed_factor
:
int
,
packed_dim
:
int
,
weight_bits
:
int
,
**
kwargs
):
super
().
__init__
(
packed_factor
,
packed_dim
,
None
,
**
kwargs
)
self
.
weight_bits
=
weight_bits
self
.
input_shape
=
self
.
shape
[
self
.
input_dim
]
*
self
.
packed_factor
self
.
output_shape
=
self
.
shape
[
self
.
output_dim
]
def
load_merged_column_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
loaded_weight
=
self
.
unpack_4bit_u8
(
loaded_weight
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
self
.
input_shape
).
transpose
(
1
,
0
)
loaded_weight
=
gptq_pack
(
loaded_weight
,
self
.
weight_bits
,
loaded_weight
.
shape
[
0
],
loaded_weight
.
shape
[
1
])
super
().
load_merged_column_weight
(
loaded_weight
,
**
kwargs
)
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
loaded_weight
=
self
.
unpack_4bit_u8
(
loaded_weight
)
loaded_weight
=
loaded_weight
.
reshape
(
self
.
output_shape
,
-
1
).
transpose
(
1
,
0
)
loaded_weight
=
gptq_pack
(
loaded_weight
,
self
.
weight_bits
,
loaded_weight
.
shape
[
0
],
loaded_weight
.
shape
[
1
])
super
().
load_row_parallel_weight
(
loaded_weight
)
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
loaded_weight
=
self
.
unpack_4bit_u8
(
loaded_weight
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
self
.
input_shape
).
transpose
(
1
,
0
)
loaded_weight
=
gptq_pack
(
loaded_weight
,
self
.
weight_bits
,
loaded_weight
.
shape
[
0
],
loaded_weight
.
shape
[
1
])
super
().
load_qkv_weight
(
loaded_weight
,
**
kwargs
)
# Zero points and scales in HQQ must also be reshaped to correspond to W_q's
# GPTQ shape (transposed - we transpose them too when processing weights).
class
HQQZeroScaleParameter
(
GroupQuantScaleParameter
):
def
load_merged_column_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
self
.
shape
[
1
])
super
().
load_merged_column_weight
(
loaded_weight
,
**
kwargs
)
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
loaded_weight
=
loaded_weight
.
reshape
(
self
.
shape
[
0
],
-
1
)
super
().
load_row_parallel_weight
(
loaded_weight
)
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
loaded_weight
=
loaded_weight
.
reshape
(
-
1
,
self
.
shape
[
1
])
super
().
load_qkv_weight
(
loaded_weight
,
**
kwargs
)
class
HQQMarlinMethod
(
LinearMethodBase
):
"""Linear method for HQQ Marlin.
"""
def
__init__
(
self
,
quant_config
:
HQQMarlinConfig
,
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
self
.
output_size_per_partition
=
sum
(
output_partition_sizes
)
self
.
input_size_per_partition
=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
,
error_loader
)
self
.
scales_and_zp_size
=
(
input_size_per_partition
//
self
.
quant_config
.
group_size
)
qweight
=
HQQweightParameter
(
data
=
torch
.
empty
(
self
.
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
self
.
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_bits
=
self
.
quant_config
.
weight_bits
,
weight_loader
=
weight_loader
)
zeros
=
HQQZeroScaleParameter
(
data
=
torch
.
empty
(
self
.
output_size_per_partition
,
self
.
scales_and_zp_size
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
scales
=
HQQZeroScaleParameter
(
data
=
torch
.
empty
(
self
.
output_size_per_partition
,
self
.
scales_and_zp_size
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"W_q"
,
qweight
)
layer
.
register_parameter
(
"zero"
,
zeros
)
layer
.
register_parameter
(
"scale"
,
scales
)
# Ignore extra parameters in the HQQ model.
# To be added as needed.
ignore_parameters
=
(
"axis"
,
"channel_wise"
,
"compute_dtype"
,
"encoded_state_dict"
,
"group_size"
,
"nbits"
,
"offload_meta"
,
"optimize"
,
"packing"
,
"quant_scale"
,
"quant_zero"
,
"round_zero"
,
"shape"
,
"stores_quant_config"
,
"unpack_view_dtype"
,
"view_as_float"
)
for
name
in
ignore_parameters
:
layer
.
register_parameter
(
name
,
HQQEmptyParameter
(
data
=
torch
.
empty
(
0
),
weight_loader
=
weight_loader
))
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
dev
=
layer
.
W_q
.
device
# Repack to Marlin
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
dev
)
marlin_w_q
=
ops
.
gptq_marlin_repack
(
layer
.
W_q
,
sort_indices
,
self
.
input_size_per_partition
,
self
.
output_size_per_partition
,
self
.
quant_config
.
weight_bits
,
).
to
(
dev
)
marlin_s
=
marlin_permute_scales
(
layer
.
scale
.
transpose
(
1
,
0
),
self
.
input_size_per_partition
,
self
.
output_size_per_partition
,
self
.
quant_config
.
group_size
).
to
(
dev
)
marlin_zp
=
marlin_permute_scales
(
layer
.
zero
.
transpose
(
1
,
0
),
self
.
input_size_per_partition
,
self
.
output_size_per_partition
,
self
.
quant_config
.
group_size
).
to
(
dev
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
dev
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
dev
)
layer
.
marlin_qweight
=
marlin_w_q
layer
.
marlin_zeros
=
marlin_zp
layer
.
marlin_scales
=
marlin_s
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
workspace
=
MarlinWorkspace
(
self
.
output_size_per_partition
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
scales
=
layer
.
marlin_scales
zeros
=
layer
.
marlin_zeros
orig_type
=
x
.
dtype
if
orig_type
!=
torch
.
float16
:
x
=
x
.
to
(
torch
.
float16
)
scales
=
scales
.
to
(
torch
.
float16
)
zeros
=
zeros
.
to
(
torch
.
float16
)
marlin_out
=
ops
.
gptq_marlin_gemm
(
x
,
layer
.
marlin_qweight
,
scales
,
zeros
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
workspace
.
scratch
,
scalar_types
.
uint4
,
x
.
shape
[
0
],
self
.
output_size_per_partition
,
self
.
input_size_per_partition
,
True
,
# is_k_full
True
,
# has_zp
True
,
# use 32-bit reduce
True
,
# use float zp
)
if
orig_type
!=
torch
.
float16
:
marlin_out
=
marlin_out
.
to
(
orig_type
)
if
bias
is
not
None
:
marlin_out
.
add_
(
bias
)
return
marlin_out
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
b00b33d7
...
...
@@ -303,7 +303,8 @@ def apply_gptq_marlin_linear(
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
)
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
@@ -340,7 +341,8 @@ def apply_awq_marlin_linear(
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
)
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment