Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
75acdaa4
Unverified
Commit
75acdaa4
authored
Jul 27, 2024
by
Alexander Matveev
Committed by
GitHub
Jul 27, 2024
Browse files
[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)
parent
fad5576c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
168 additions
and
44 deletions
+168
-44
benchmarks/kernels/benchmark_marlin.py
benchmarks/kernels/benchmark_marlin.py
+18
-5
csrc/ops.h
csrc/ops.h
+2
-1
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+122
-28
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+10
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+3
-3
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+13
-4
No files found.
benchmarks/kernels/benchmark_marlin.py
View file @
75acdaa4
...
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
...
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_
MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_
MARLIN_SUPPORTED_NUM_BITS
)
MARLIN_SUPPORTED_GROUP_SIZES
,
MARLIN_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
marlin_quantize
)
MarlinWorkspace
,
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
...
@@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str,
...
@@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str,
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
(
marlin_24_w_ref
,
marlin_24_q_w_comp
,
marlin_24_meta
,
marlin_24_s
)
=
marlin_24_quantize
(
b
,
num_bits
,
group_size
)
marlin_24_s
)
=
marlin_24_quantize
(
b
,
num_bits
,
group_size
)
marlin_zp
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
b
.
device
)
# GPTQ quant
# GPTQ quant
(
w_ref
,
q_w
,
s
,
g_idx
,
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
)
=
quantize_weights
(
b
,
num_bits
,
group_size
,
act_order
)
rand_perm
)
=
quantize_weights
(
b
,
num_bits
,
group_size
,
act_order
)
...
@@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
...
@@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_w_ref"
:
marlin_w_ref
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_q_w"
:
marlin_q_w
,
"marlin_s"
:
marlin_s
,
"marlin_s"
:
marlin_s
,
"marlin_zp"
:
marlin_zp
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_g_idx"
:
marlin_g_idx
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_sort_indices"
:
marlin_sort_indices
,
"marlin_rand_perm"
:
marlin_rand_perm
,
"marlin_rand_perm"
:
marlin_rand_perm
,
...
@@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
...
@@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results
.
append
(
results
.
append
(
benchmark
.
Timer
(
benchmark
.
Timer
(
stmt
=
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)"
,
# noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)"
,
# noqa: E501
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_gemm_fp16"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
results
.
append
(
benchmark
.
Timer
(
stmt
=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)"
,
# noqa: E501
globals
=
globals
,
globals
=
globals
,
label
=
label
,
label
=
label
,
sub_label
=
sub_label
,
sub_label
=
sub_label
,
description
=
"gptq_marlin_gemm"
,
description
=
"gptq_marlin_gemm
_fp32
"
,
).
blocked_autorange
(
min_run_time
=
min_run_time
))
).
blocked_autorange
(
min_run_time
=
min_run_time
))
if
(
num_bits
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
if
(
num_bits
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
...
@@ -183,12 +196,12 @@ def main(args):
...
@@ -183,12 +196,12 @@ def main(args):
)
>
0
and
is_k_full
not
in
args
.
limit_k_full
:
)
>
0
and
is_k_full
not
in
args
.
limit_k_full
:
continue
continue
for
num_bits
in
GPTQ_
MARLIN_SUPPORTED_NUM_BITS
:
for
num_bits
in
MARLIN_SUPPORTED_NUM_BITS
:
if
len
(
args
.
limit_num_bits
if
len
(
args
.
limit_num_bits
)
>
0
and
num_bits
not
in
args
.
limit_num_bits
:
)
>
0
and
num_bits
not
in
args
.
limit_num_bits
:
continue
continue
for
group_size
in
GPTQ_
MARLIN_SUPPORTED_GROUP_SIZES
:
for
group_size
in
MARLIN_SUPPORTED_GROUP_SIZES
:
if
len
(
if
len
(
args
.
limit_group_size
args
.
limit_group_size
)
>
0
and
group_size
not
in
args
.
limit_group_size
:
)
>
0
and
group_size
not
in
args
.
limit_group_size
:
...
...
csrc/ops.h
View file @
75acdaa4
...
@@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
);
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
75acdaa4
...
@@ -59,6 +59,7 @@ __global__ void Marlin(
...
@@ -59,6 +59,7 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
const
int
*
__restrict__
g_idx
,
// int32 group indices of shape k
...
@@ -66,7 +67,8 @@ __global__ void Marlin(
...
@@ -66,7 +67,8 @@ __global__ void Marlin(
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{}
)
{}
}
// namespace gptq_marlin
}
// namespace gptq_marlin
...
@@ -532,6 +534,7 @@ __global__ void Marlin(
...
@@ -532,6 +534,7 @@ __global__ void Marlin(
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
A
,
// fp16 input matrix of shape mxk
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
const
int4
*
__restrict__
B
,
// 4bit quantized weight matrix of shape kxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C
,
// fp16 output buffer of shape mxn
int4
*
__restrict__
C_tmp
,
// fp32 tmp output buffer (for reduce)
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
const
int4
*
__restrict__
scales_ptr
,
// fp16 quantization scales of shape
// (k/groupsize)xn
// (k/groupsize)xn
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
const
int4
*
__restrict__
zp_ptr
,
// 4bit packed zero-points of shape
...
@@ -541,7 +544,8 @@ __global__ void Marlin(
...
@@ -541,7 +544,8 @@ __global__ void Marlin(
int
prob_m
,
// batch dimension m
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
prob_k
,
// reduction dimension k
int
*
locks
// extra global storage for barrier synchronization
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_fp32_reduce
// whether to use fp32 global reduce
)
{
)
{
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// same size, which might involve multiple column "slices" (of width 16 *
...
@@ -595,6 +599,8 @@ __global__ void Marlin(
...
@@ -595,6 +599,8 @@ __global__ void Marlin(
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
int
slice_idx
;
// index of threadblock in current slice; numbered bottom to
// top
// top
int
par_id
=
0
;
// We can easily implement parallel problem execution by just remapping
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
if
(
slice_col_par
>=
n_tiles
)
{
...
@@ -602,6 +608,7 @@ __global__ void Marlin(
...
@@ -602,6 +608,7 @@ __global__ void Marlin(
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
par_id
=
slice_col_par
/
n_tiles
;
}
}
// Compute all information about the current slice which is required for
// Compute all information about the current slice which is required for
...
@@ -632,6 +639,7 @@ __global__ void Marlin(
...
@@ -632,6 +639,7 @@ __global__ void Marlin(
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
locks
+=
n_tiles
;
slice_col
=
0
;
slice_col
=
0
;
par_id
++
;
}
}
};
};
init_slice
();
init_slice
();
...
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
...
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
// usually rather small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
_fp16
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
// We are very careful here to reduce directly in the output buffer to
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
// results in FP16 (but still reduce with FP32 compute).
...
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
...
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
}
}
};
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
//
// Parameters:
//   first - true for the first threadblock of the slice; it has nothing to
//           accumulate yet, so the read-and-add step is skipped.
//   last  - true for the last threadblock of the slice; it keeps the final sum
//           in frag_c instead of writing it back to C_tmp.
//
// The call site invokes this as global_reduce_fp32(slice_idx == 0, last)
// between barrier_acquire() and barrier_release() on locks[slice_col], and
// only when use_fp32_reduce is set.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
  // Tile extents, computed as 16 * the per-dimension block counts.
  constexpr int tb_m = thread_m_blocks * 16;
  constexpr int tb_n = thread_n_blocks * 16;

  // Size of one tb_m x tb_n tile of floats, expressed in 16-byte units
  // (C_tmp is indexed as int4, so offsets below are in int4 elements).
  constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;

  // Only the first `active_threads` threads of the block take part.
  constexpr int active_threads = 32 * thread_n_blocks / 4;
  bool is_th_active = threadIdx.x < active_threads;

  // Each parallel problem (par_id) owns n_tiles tiles in C_tmp; within it,
  // the tile is selected by the current slice column.
  int par_offset = c_size * n_tiles * par_id;
  int slice_offset = c_size * slice_col;

  // Number of floats handled per thread and the same amount in int4 units.
  // NOTE(review): presumably this matches the per-thread size of frag_c,
  // which is declared outside this view - confirm.
  constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
  constexpr int th_size = num_floats * sizeof(float) / 16;

  int c_cur_offset = par_offset + slice_offset;

  if (!is_th_active) {
    return;
  }

  if (!first) {
    // Add the partial sums left in C_tmp by the previous threadblock(s) into
    // this thread's accumulators, viewed as a flat float array.
    float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
    for (int k = 0; k < th_size; k++) {
      // Stage one int4 per thread through sh[threadIdx.x]; for a fixed k,
      // consecutive threads read consecutive int4 elements of C_tmp.
      // NOTE(review): sh looks like the kernel's shared-memory buffer - its
      // declaration is outside this view.
      sh[threadIdx.x] =
          C_tmp[c_cur_offset + active_threads * k + threadIdx.x];

      // Reinterpret the staged int4 as 4 floats and accumulate them.
      float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
#pragma unroll
      for (int f = 0; f < 4; f++) {
        frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
      }
    }
  }

  if (!last) {
    // Publish the running sum to C_tmp, using the same indexing as the read
    // above, so the next threadblock of the slice can pick it up.
    int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
    for (int k = 0; k < th_size; k++) {
      C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
    }
  }
};
// Write out the reduce final result in the correct layout. We only actually
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
// in fragment layout.
...
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
...
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
// block in a slice
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
barrier_acquire
(
&
locks
[
slice_col
],
slice_idx
);
global_reduce
(
slice_idx
==
0
,
last
);
if
(
use_fp32_reduce
)
{
global_reduce_fp32
(
slice_idx
==
0
,
last
);
}
else
{
global_reduce_fp16
(
slice_idx
==
0
,
last
);
}
barrier_release
(
&
locks
[
slice_col
],
last
);
barrier_release
(
&
locks
[
slice_col
],
last
);
}
}
if
(
last
)
// only the last block in a slice actually writes the result
if
(
last
)
// only the last block in a slice actually writes the result
...
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
...
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr,
num_groups,
\
A_ptr, B_ptr, C_ptr,
C_tmp_ptr,
s_ptr, zp_ptr, g_idx_ptr, \
prob_m, prob_n, prob_k, locks
);
\
num_groups,
prob_m, prob_n, prob_k, locks
, use_fp32_reduce);
\
}
}
typedef
struct
{
typedef
struct
{
...
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
...
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return
true
;
return
true
;
}
}
// Returns the number of rows (M) to allocate for the fp32 reduce buffer.
//
// For prob_m of at most 4 tiles of 16 rows, the result is prob_m rounded up
// to the next multiple of 16 (with a minimum of one tile). Above that, the
// result is a whole number of 64-row groups: ceil(prob_m / 64) groups, capped
// at max_par groups.
int determine_reduce_max_m(int prob_m, int max_par) {
  constexpr int tile_m_size = 16;
  constexpr int max_m_tiles = 4;
  constexpr int group_m_size = tile_m_size * max_m_tiles;

  // Small problems: pick the first tile count (1..4) that covers prob_m.
  for (int num_tiles = 1; num_tiles <= max_m_tiles; ++num_tiles) {
    const int covered_m = tile_m_size * num_tiles;
    if (prob_m <= covered_m) {
      return covered_m;
    }
  }

  // Large problems: count full 4-tile groups, limited by max_par.
  const int num_groups = min(div_ceil(prob_m, group_m_size), max_par);
  return group_m_size * num_groups;
}
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_act_order
,
bool
is_k_full
,
...
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
z
p
,
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tm
p
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_bits
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
int
group_size
,
int
dev
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
int
sms
,
int
max_par
)
{
int
max_par
,
bool
use_fp32_reduce
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
...
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
...
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
A_ptr
=
(
const
int4
*
)
A
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
const
int4
*
B_ptr
=
(
const
int4
*
)
B
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_ptr
=
(
int4
*
)
C
;
int4
*
C_tmp_ptr
=
(
int4
*
)
C_tmp
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
s_ptr
=
(
const
int4
*
)
s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
...
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
)
{
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
)
{
// Verify num_bits
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
"num_bits must be 4 or 8. Got = "
,
num_bits
);
...
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
c
=
torch
::
empty
({
size_m
,
size_n
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
torch
::
Tensor
a_tmp
=
torch
::
empty
({
size_m
,
size_k
},
options
);
// Alloc C tmp buffer that is going to be used for the global reduce
int
reduce_max_m
=
marlin
::
determine_reduce_max_m
(
size_m
,
marlin
::
max_par
);
int
reduce_n
=
size_n
;
auto
options_fp32
=
torch
::
TensorOptions
().
dtype
(
at
::
kFloat
).
device
(
a
.
device
());
if
(
!
use_fp32_reduce
)
{
reduce_max_m
=
0
;
reduce_n
=
0
;
}
torch
::
Tensor
c_tmp
=
torch
::
empty
({
reduce_max_m
,
reduce_n
},
options_fp32
);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
// auto -1)
int
thread_k
=
-
1
;
int
thread_k
=
-
1
;
...
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
...
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
marlin
::
marlin_mm_f16i4
<
half
>
(
marlin
::
marlin_mm_f16i4
<
half
>
(
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BF
loat
16
>
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
f
loat
>
(),
b_
zero
s
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
b_
scale
s
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_fp32_reduce
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
}
...
...
tests/kernels/test_marlin_gemm.py
View file @
75acdaa4
...
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ACT_ORDER_OPTS
=
[
False
,
True
]
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
USE_FP32_REDUCE_OPTS
=
[
False
,
True
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_K_CHUNKS
=
[
128
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
MARLIN_N_CHUNKS
=
[
64
,
128
,
256
]
...
@@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
...
@@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_gptq_marlin_gemm
(
def
test_gptq_marlin_gemm
(
k_chunk
,
k_chunk
,
n_chunk
,
n_chunk
,
...
@@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
...
@@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
mnk_factors
,
mnk_factors
,
act_order
,
act_order
,
is_k_full
,
is_k_full
,
use_fp32_reduce
,
):
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
@@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
...
@@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
,
)
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
@@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
...
@@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
MARLIN_SUPPORTED_NUM_BITS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_awq_marlin_gemm
(
def
test_awq_marlin_gemm
(
k_chunk
,
k_chunk
,
n_chunk
,
n_chunk
,
num_bits
,
num_bits
,
group_size
,
group_size
,
mnk_factors
,
mnk_factors
,
use_fp32_reduce
,
):
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
...
@@ -407,8 +413,9 @@ def test_awq_marlin_gemm(
...
@@ -407,8 +413,9 @@ def test_awq_marlin_gemm(
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
)
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
...
...
vllm/_custom_ops.py
View file @
75acdaa4
...
@@ -286,12 +286,12 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -286,12 +286,12 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_zeros
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
has_zp
:
bool
,
has_zp
:
bool
)
->
torch
.
Tensor
:
use_fp32_reduce
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
return
torch
.
ops
.
_C
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
b_zeros
,
g_idx
,
perm
,
workspace
,
num_bits
,
g_idx
,
perm
,
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
,
size_m
,
size_n
,
size_k
,
is_k_full
,
has_zp
)
has_zp
,
use_fp32_reduce
)
# fp8 marlin
# fp8 marlin
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
75acdaa4
...
@@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
...
@@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT
=
True
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
min_capability
:
Optional
[
int
],
...
@@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
...
@@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
output_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
is_k_full
:
bool
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
@@ -260,7 +266,8 @@ def apply_gptq_marlin_linear(
...
@@ -260,7 +266,8 @@ def apply_gptq_marlin_linear(
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
)
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
@@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
...
@@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
num_bits
:
int
,
num_bits
:
int
,
output_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
@@ -295,7 +303,8 @@ def apply_awq_marlin_linear(
...
@@ -295,7 +303,8 @@ def apply_awq_marlin_linear(
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
is_k_full
=
True
,
has_zp
=
True
)
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment