Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
60662532
Unverified
Commit
60662532
authored
May 23, 2024
by
Alexander Matveev
Committed by
GitHub
May 23, 2024
Browse files
Marlin 24 prefill performance improvement (about 25% better on average) (#4983)
parent
ee3eea0a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
107 additions
and
32 deletions
+107
-32
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+62
-12
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+40
-15
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+4
-4
No files found.
benchmarks/kernels/benchmark_marlin.py
View file @
60662532
...
@@ -6,9 +6,13 @@ from benchmark_shapes import WEIGHT_SHAPES
...
@@ -6,9 +6,13 @@ from benchmark_shapes import WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
marlin_quantize
)
MarlinWorkspace
,
marlin_24_quantize
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
gptq_pack
,
quantize_weights
,
sort_weights
)
...
@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
...
@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
marlin_rand_perm
,
marlin_rand_perm
,
)
=
marlin_quantize
(
b
,
num_bits
,
group_size
,
act_order
)
)
=
marlin_quantize
(
b
,
num_bits
,
group_size
,
act_order
)
# Marlin_24 quant
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
marlin_24_quantize
(
b
,
num_bits
,
group_size
)
# GPTQ quant
# GPTQ quant
(
w_ref
,
q_w
,
s
,
g_idx
,
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
)
=
quantize_weights
(
b
,
num_bits
,
group_size
,
act_order
)
rand_perm
)
=
quantize_weights
(
b
,
num_bits
,
group_size
,
act_order
)
...
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
...
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
(
q_w
,
g_idx
,
repack_sort_indices
)
=
sort_weights
(
q_w
,
g_idx
)
# Prepare
# Prepare
marlin_workspace
=
MarlinWorkspace
(
size_n
)
marlin_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
marlin_24_workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_MAX_PARALLEL
)
globals
=
{
globals
=
{
# Gen params
"num_bits"
:
num_bits
,
"group_size"
:
group_size
,
"size_m"
:
size_m
,
"size_n"
:
size_n
,
"size_k"
:
size_k
,
"a"
:
a
,
"a_tmp"
:
a_tmp
,
# Marlin params
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_s"
:
marlin_s
,
"marlin_s"
:
marlin_s
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_rand_perm"
:
marlin_rand_perm
,
"marlin_rand_perm"
:
marlin_rand_perm
,
"marlin_workspace"
:
marlin_workspace
,
"is_k_full"
:
is_k_full
,
# Marlin_24 params
"marlin_24_w_ref"
:
marlin_24_w_ref
,
"marlin_24_q_w_comp"
:
marlin_24_q_w_comp
,
"marlin_24_meta"
:
marlin_24_meta
,
"marlin_24_s"
:
marlin_24_s
,
"marlin_24_workspace"
:
marlin_24_workspace
,
# GPTQ params
"q_w_gptq"
:
q_w_gptq
,
"q_w_gptq"
:
q_w_gptq
,
"repack_sort_indices"
:
repack_sort_indices
,
"repack_sort_indices"
:
repack_sort_indices
,
"num_bits"
:
num_bits
,
# Kernels
"group_size"
:
group_size
,
"size_m"
:
size_m
,
"size_n"
:
size_n
,
"size_k"
:
size_k
,
"is_k_full"
:
is_k_full
,
"a"
:
a
,
"a_tmp"
:
a_tmp
,
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_gemm"
:
ops
.
gptq_marlin_gemm
,
"gptq_marlin_24_gemm"
:
ops
.
gptq_marlin_24_gemm
,
"gptq_marlin_repack"
:
ops
.
gptq_marlin_repack
,
"gptq_marlin_repack"
:
ops
.
gptq_marlin_repack
,
"marlin_workspace"
:
marlin_workspace
,
}
}
min_run_time
=
1
min_run_time
=
1
...
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
...
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
description
=
"gptq_marlin_gemm"
,
description
=
"gptq_marlin_gemm"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
))
if
(
num_bits
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
):
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_24_gemm"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
...
@@ -135,8 +170,20 @@ def main(args):
...
@@ -135,8 +170,20 @@ def main(args):
continue
continue
for
act_order
in
ACT_ORDER_OPTS
:
for
act_order
in
ACT_ORDER_OPTS
:
if
len
(
args
.
limit_act_order
)
>
0
and
act_order
not
in
args
.
limit_act_order
:
continue
for
is_k_full
in
K_FULL_OPTS
:
for
is_k_full
in
K_FULL_OPTS
:
if
len
(
args
.
limit_k_full
)
>
0
and
is_k_full
not
in
args
.
limit_k_full
:
continue
for
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
for
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
if
len
(
args
.
limit_num_bits
)
>
0
and
num_bits
not
in
args
.
limit_num_bits
:
continue
for
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
for
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
if
len
(
if
len
(
args
.
limit_group_size
args
.
limit_group_size
...
@@ -159,7 +206,7 @@ def main(args):
...
@@ -159,7 +206,7 @@ def main(args):
# For quick benchmarking use:
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
#
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -178,6 +225,9 @@ if __name__ == "__main__":
...
@@ -178,6 +225,9 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--limit-k"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-k"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-n"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-n"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-group-size"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-group-size"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-num-bits"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-act-order"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
parser
.
add_argument
(
"--limit-k-full"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
60662532
...
@@ -48,12 +48,12 @@ namespace marlin_24 {
...
@@ -48,12 +48,12 @@ namespace marlin_24 {
// than 1 warp per schedule allows some more latency hiding. At the same time,
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
// we want relatively few warps to have many registers per warp and small tiles.
static
constexpr
int
THREADS
=
256
;
static
constexpr
int
THREADS
=
256
;
static
constexpr
int
STAGES
=
4
;
// 4 pipeline stages fit into shared memory
static
constexpr
int
STAGES
=
4
;
static
constexpr
int
min_thread_n
=
128
;
static
constexpr
int
min_thread_n
=
128
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
max_par
=
64
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
...
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
fetch_to_shared
((
pipe
+
stages
-
1
)
%
stages
,
pipe
,
slice_iters
>=
stages
);
slice_iters
>=
stages
);
matmul
(
pipe
);
wait_for_stage
();
wait_for_stage
();
fetch_to_registers
(
pipe
+
1
,
(
pipe
+
1
)
%
stages
);
fetch_to_registers
(
pipe
+
1
,
(
pipe
+
1
)
%
stages
);
matmul
(
pipe
);
pipe
++
;
pipe
++
;
slice_iters
--
;
slice_iters
--
;
...
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization
// than better compute utilization
thread_k
=
128
;
thread_k
=
128
;
thread_m
=
128
;
thread_m
=
128
;
}
else
{
}
else
if
(
prob_n
<=
256
)
{
thread_k
=
64
;
thread_k
=
64
;
thread_m
=
256
;
thread_m
=
256
;
}
else
{
thread_k
=
32
;
thread_m
=
512
;
}
}
}
}
...
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
constexpr
int
max_m_blocks
=
4
;
int
*
locks
=
(
int
*
)
workspace
;
int
*
locks
=
(
int
*
)
workspace
;
for
(
int
i
=
0
;
i
<
tot_n_blocks
;
i
+=
4
)
{
for
(
int
i
=
0
;
i
<
tot_n_blocks
;
i
+=
max_m_blocks
)
{
int
thread_n_blocks
=
tot_n_blocks
-
i
;
int
thread_n_blocks
=
tot_n_blocks
-
i
;
prob_n
=
tot_n
-
16
*
i
;
prob_n
=
tot_n
-
16
*
i
;
int
par
=
1
;
int
par
=
1
;
if
(
thread_n_blocks
>
4
)
{
if
(
thread_n_blocks
>
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// Note that parallel > 1 currently only works for inputs without any
// padding
// padding
par
=
(
16
*
thread_n_blocks
-
pad
)
/
64
;
par
=
(
16
*
thread_n_blocks
-
pad
)
/
(
max_m_blocks
*
16
)
;
if
(
par
>
max_par
)
par
=
max_par
;
if
(
par
>
max_par
)
par
=
max_par
;
prob_n
=
64
*
par
;
prob_n
=
(
max_m_blocks
*
16
)
*
par
;
i
+=
4
*
(
par
-
1
);
i
+=
max_m_blocks
*
(
par
-
1
);
thread_n_blocks
=
4
;
thread_n_blocks
=
max_m_blocks
;
}
}
// For compilation speed, we only define the kernel configurations that have
// For compilation speed, we only define the kernel configurations that have
...
@@ -953,6 +958,7 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -953,6 +958,7 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// 4-bit
// 4-bit
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
4
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
4
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
4
,
16
,
2
,
2
,
-
1
)
// e.g., 32x256x64
CALL_IF_2_4
(
4
,
16
,
2
,
2
,
-
1
)
// e.g., 32x256x64
...
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
4
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
4
,
32
,
1
,
1
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
4
,
32
,
1
,
1
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
4
,
32
,
2
,
1
,
-
1
)
// e.g., 32x256x64
CALL_IF_2_4
(
4
,
32
,
2
,
1
,
4
)
CALL_IF_2_4
(
4
,
32
,
3
,
1
,
-
1
)
CALL_IF_2_4
(
4
,
32
,
3
,
1
,
4
)
CALL_IF_2_4
(
4
,
32
,
4
,
1
,
-
1
)
CALL_IF_2_4
(
4
,
32
,
4
,
1
,
4
)
// 8-bit
// 8-bit
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
-
1
)
// e.g., 16x128x128
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
8
,
8
,
1
,
4
,
4
)
// e.g., 16x128x128, 64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
8
,
16
,
1
,
2
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
8
,
16
,
2
,
2
,
-
1
)
// e.g., 32x256x64
CALL_IF_2_4
(
8
,
16
,
2
,
2
,
-
1
)
// e.g., 32x256x64
...
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
...
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4
(
8
,
16
,
3
,
2
,
4
)
CALL_IF_2_4
(
8
,
16
,
3
,
2
,
4
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
-
1
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
8
,
16
,
4
,
2
,
4
)
CALL_IF_2_4
(
8
,
32
,
1
,
1
,
-
1
)
// e.g., 16x256x64
CALL_IF_2_4
(
8
,
32
,
1
,
1
,
4
)
// e.g., 16x256x64, 64
CALL_IF_2_4
(
8
,
32
,
2
,
1
,
-
1
)
// e.g., 32x256x64
CALL_IF_2_4
(
8
,
32
,
2
,
1
,
4
)
CALL_IF_2_4
(
8
,
32
,
3
,
1
,
-
1
)
CALL_IF_2_4
(
8
,
32
,
3
,
1
,
4
)
CALL_IF_2_4
(
8
,
32
,
4
,
1
,
-
1
)
CALL_IF_2_4
(
8
,
32
,
4
,
1
,
4
)
else
{
else
{
throw
std
::
runtime_error
(
"Unsupported shapes: MKN = ["
+
str
(
prob_m
)
+
throw
std
::
runtime_error
(
"Unsupported shapes: MKN = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_k
)
+
", "
+
str
(
prob_n
)
+
"]"
+
", "
+
str
(
prob_k
)
+
", "
+
str
(
prob_n
)
+
"]"
+
...
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int
thread_k
=
-
1
;
int
thread_k
=
-
1
;
int
thread_m
=
-
1
;
int
thread_m
=
-
1
;
int
sms
=
-
1
;
int
sms
=
-
1
;
int
max_par
=
16
;
int
max_par
=
marlin_24
::
max_par
;
int
groupsize
=
-
1
;
int
groupsize
=
-
1
;
if
(
b_scales
.
size
(
0
)
>
1
)
{
if
(
b_scales
.
size
(
0
)
>
1
)
{
...
...
tests/kernels/test_marlin_gemm.py
View file @
60662532
...
@@ -27,7 +27,7 @@ MARLIN_K_CHUNKS = [128]
...
@@ -27,7 +27,7 @@ MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_24_K_CHUNKS
=
[
128
]
MARLIN_24_K_CHUNKS
=
[
128
]
MARLIN_24_N_CHUNKS
=
[
256
]
MARLIN_24_N_CHUNKS
=
[
512
]
MNK_FACTORS
=
[
MNK_FACTORS
=
[
(
1
,
1
,
1
),
(
1
,
1
,
1
),
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
60662532
...
@@ -15,7 +15,7 @@ logger = init_logger(__name__)
...
@@ -15,7 +15,7 @@ logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE
=
16
GPTQ_MARLIN_24_TILE
=
16
GPTQ_MARLIN_24_MIN_THREAD_N
=
128
GPTQ_MARLIN_24_MIN_THREAD_N
=
128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
16
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
...
@@ -53,14 +53,14 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -53,14 +53,14 @@ class GPTQMarlin24Config(QuantizationConfig):
self
.
tile_size
=
16
self
.
tile_size
=
16
# Min out_features dim
# Min out_features dim
self
.
min_n_threads
=
128
self
.
min_n_threads
=
GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
# Min in_features dim
self
.
min_k_threads
=
128
self
.
min_k_threads
=
GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# Max parallel problems to solve at once (improves large
# batch performance)
# batch performance)
self
.
max_parallel
=
16
self
.
max_parallel
=
GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
# Permutation length used by the marlin kernels.
self
.
perm_len
=
1024
self
.
perm_len
=
1024
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment