Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7038e8b8
Unverified
Commit
7038e8b8
authored
May 02, 2024
by
alexm-nm
Committed by
GitHub
May 02, 2024
Browse files
[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)
parent
2a85f930
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
553 additions
and
324 deletions
+553
-324
csrc/ops.h
csrc/ops.h
+3
-1
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+377
-175
csrc/quantization/gptq_marlin/gptq_marlin.cuh
csrc/quantization/gptq_marlin/gptq_marlin.cuh
+2
-6
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+90
-62
tests/models/test_gptq_marlin.py
tests/models/test_gptq_marlin.py
+9
-4
vllm/_custom_ops.py
vllm/_custom_ops.py
+8
-6
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+64
-70
No files found.
csrc/ops.h
View file @
7038e8b8
...
@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
...
@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_k
,
...
@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
...
@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_k
,
int64_t
size_n
);
int64_t
size_n
,
int64_t
num_bits
);
#endif
#endif
void
squeezellm_gemm
(
void
squeezellm_gemm
(
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
7038e8b8
...
@@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
...
@@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{}
int
size_k
,
int
block_rows
)
{}
template
<
const
int
threads
,
// number of threads in a threadblock
template
<
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
// dimension (batchsize) of the threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_n_blocks
,
// same for n dimension (output)
...
@@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
bool
is_k_full
)
{
int64_t
size_k
,
bool
is_k_full
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
return
torch
::
empty
({
1
,
1
});
...
@@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
...
@@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return
res
;
return
res
;
}
}
// Constructs destination register by taking bytes from 2 sources (based on mask)
template
<
int
start_byte
,
int
mask
>
__device__
inline
uint32_t
prmt
(
uint32_t
a
)
{
uint32_t
res
;
asm
volatile
(
"prmt.b32 %0, %1, %2, %3;
\n
"
:
"=r"
(
res
)
:
"r"
(
a
),
"n"
(
start_byte
),
"n"
(
mask
));
return
res
;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// values. We mostly follow the strategy in the link below, with some small
// changes:
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__
inline
FragB
dequant
(
int
q
)
{
__device__
inline
FragB
dequant
_4bit
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
...
@@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
...
@@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
return
frag_b
;
return
frag_b
;
}
}
__device__
inline
FragB
dequant_8bit
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
uint32_t
lo
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_01
>
(
q
);
uint32_t
hi
=
prmt
<
start_byte_for_fp16
,
mask_for_elt_23
>
(
q
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
// only for grouped quantization.
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
...
@@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
...
@@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
}
// Given 2 floats multiply by 2 scales (halves)
__device__
inline
void
scale_float
(
float
*
c
,
FragS
&
s
)
{
__half
*
s_ptr
=
reinterpret_cast
<
__half
*>
(
&
s
);
c
[
0
]
=
__fmul_rn
(
c
[
0
],
__half2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
__half2float
(
s_ptr
[
1
]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
...
@@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
}
}
}
}
template
<
const
int
threads
,
// number of threads in a threadblock
template
<
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
// dimension (batchsize) of the threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_n_blocks
,
// same for n dimension (output)
...
@@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// configurations, while requiring as few slow global cross-threadblock
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// reductions as possible.
constexpr
int
pack_factor
=
32
/
num_bits
;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
// better partitioning with less reductions
int
parallel
=
1
;
int
parallel
=
1
;
...
@@ -385,19 +424,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -385,19 +424,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
constexpr
int
a_sh_wr_iters
=
div_ceil
(
a_sh_stage
,
a_sh_wr_delta
);
// B sizes/strides
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
32
;
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
32
*
thread_n_blocks
/
4
;
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num_bits
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride
);
int
b_gl_rd_delta_i
=
b_gl_stride
*
(
threads
/
b_sh_stride
_threads
);
constexpr
int
b_sh_wr_delta
=
threads
;
constexpr
int
b_sh_wr_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_rd_delta
=
threads
;
constexpr
int
b_sh_rd_delta
=
threads
*
b_thread_vecs
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_stage
=
b_sh_stride
*
thread_k_blocks
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
8
;
int
s_gl_stride
=
prob_n
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
8
;
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
<
thread_k_blocks
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
:
1
;
:
1
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
constexpr
int
s_sh_stage
=
s_tb_groups
*
s_sh_stride
;
...
@@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_stride
*
((
threadIdx
.
x
%
32
)
%
16
)
+
(
threadIdx
.
x
%
32
)
/
16
;
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
a_sh_rd
+=
2
*
((
threadIdx
.
x
/
32
)
/
(
thread_n_blocks
/
4
));
int
b_gl_rd
=
int
b_gl_rd
=
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
(
thread
Idx
.
x
%
b_sh_stride
)
;
(
threadIdx
.
x
%
b_sh_stride
_
thread
s
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
...
@@ -442,9 +485,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -442,9 +485,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// No act_order
// No act_order
int
s_gl_rd
;
int
s_gl_rd
;
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
!
has_act_order
)
{
if
constexpr
(
group_blocks
==
-
1
)
{
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
else
{
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
}
int
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
...
@@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Register storage for double buffer of shared memory reads.
// Register storage for double buffer of shared memory reads.
FragA
frag_a
[
2
][
thread_m_blocks
];
FragA
frag_a
[
2
][
thread_m_blocks
];
I4
frag_b_quant
[
2
];
I4
frag_b_quant
[
2
]
[
b_thread_vecs
]
;
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragC
frag_c
[
thread_m_blocks
][
4
][
2
];
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
frag_s
[
2
][
4
];
// No act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
FragS
act_frag_s
[
2
][
4
][
4
];
// For act-order
...
@@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
for
(
int
i
=
0
;
i
<
b_sh_wr_iters
;
i
++
)
{
cp_async4_stream
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
],
B_ptr
[
i
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
b_thread_vecs
;
j
++
)
{
cp_async4
(
&
sh_b_stage
[
b_sh_wr_delta
*
i
+
b_sh_wr
+
j
],
B_ptr
[
i
]
+
j
);
}
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
B_ptr
[
i
]
+=
b_gl_rd_delta_o
;
}
}
...
@@ -602,14 +653,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -602,14 +653,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Only fetch scales if this tile starts a new group
// Only fetch scales if this tile starts a new group
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
pipe
%
(
group_blocks
/
thread_k_blocks
)
==
0
)
{
if
(
s_sh_wr_pred
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
_stream
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
cp_async4
(
&
sh_s_stage
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
}
s_gl_rd
+=
s_gl_rd_delta
;
s_gl_rd
+=
s_gl_rd_delta
;
}
}
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
s_tb_groups
;
i
++
)
{
if
(
s_sh_wr_pred
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
_stream
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
cp_async4
(
&
sh_s_stage
[
i
*
s_sh_stride
+
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
&
scales_ptr
[
s_gl_rd
]);
}
}
s_gl_rd
+=
s_gl_rd_delta
;
s_gl_rd
+=
s_gl_rd_delta
;
...
@@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
frag_b_quant
[
k
%
2
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
]);
#pragma unroll
for
(
int
i
=
0
;
i
<
b_thread_vecs
;
i
++
)
{
frag_b_quant
[
k
%
2
][
i
]
=
*
reinterpret_cast
<
I4
*>
(
&
sh_b_stage
[
b_sh_rd_delta
*
(
k
%
b_sh_wr_iters
)
+
b_sh_rd
+
i
]);
}
};
};
bool
is_same_group
[
stages
];
bool
is_same_group
[
stages
];
int
same_group_id
[
stages
];
int
same_group_id
[
stages
];
auto
init_same_group
=
[
&
](
int
pipe
)
{
auto
init_same_group
=
[
&
](
int
pipe
)
{
if
constexpr
(
!
has_act_order
)
{
is_same_group
[
pipe
]
=
false
;
same_group_id
[
pipe
]
=
0
;
return
;
}
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int4
*
sh_g_idx_stage
=
sh_g_idx
+
g_idx_stage
*
pipe
;
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
int
*
sh_g_idx_int_ptr
=
reinterpret_cast
<
int
*>
(
sh_g_idx_stage
);
...
@@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// dequantization and matmul operations.
// dequantization and matmul operations.
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
j
];
FragB
frag_b0
;
FragB
frag_b1
;
if
constexpr
(
num_bits
==
4
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
int
b_quant_shift
=
b_quant
>>
8
;
FragB
frag_b0
=
dequant
(
b_quant
);
frag_b0
=
dequant_4bit
(
b_quant
);
frag_b1
=
dequant_4bit
(
b_quant_shift
);
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
(
b_quant_0
);
frag_b1
=
dequant_8bit
(
b_quant_1
);
}
// Apply scale to frag_b0
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
...
@@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}
}
}
}
FragB
frag_b1
=
dequant
(
b_quant_shift
);
// Apply scale to frag_b1
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
...
@@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// multiple warps that accumulate their partial sums of the same output
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
// location; which we have to reduce over in the end. We do in shared memory.
auto
thread_block_reduce
=
[
&
]()
{
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride
/
2
;
constexpr
int
red_off
=
threads
/
b_sh_stride
_threads
/
2
;
if
(
red_off
>=
1
)
{
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
int
red_idx
=
threadIdx
.
x
/
b_sh_stride
_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride
*
4
*
2
;
constexpr
int
red_sh_stride
=
b_sh_stride
_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride
;
constexpr
int
red_sh_delta
=
b_sh_stride
_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride
_threads
)
+
(
threadIdx
.
x
%
b_sh_stride
);
(
threadIdx
.
x
%
b_sh_stride
_threads
);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// unnecessary read or write iterations, e.g., for two warps we write only
...
@@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};
};
// Since multiple threadblocks may process parts of the same column slice, we
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped p
or
tioning
// finally have to globally reduce over the results. As the striped p
arti
tioning
// minimizes the number of such reductions and our outputs are usually rather
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
// small, we perform this reduction serially in L2 cache.
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
auto
global_reduce
=
[
&
](
bool
first
=
false
,
bool
last
=
false
)
{
...
@@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
half2
res
=
__halves2half2
(
__float2half
(
c0
),
__float2half
(
c1
));
half2
res
=
__halves2half2
(
__float2half
(
c0
),
__float2half
(
c1
));
// For per-column quantization we finally apply the scale here
// For per-column quantization we finally apply the scale here (only for
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
res
=
__hmul2
(
res
,
s
[
0
]);
}
}
((
half2
*
)
sh
)[
idx
]
=
res
;
((
half2
*
)
sh
)[
idx
]
=
res
;
};
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
@@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensure all shared memory accesses are static. Note that both pipelines
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// have even length meaning that the next iteration will always start at
// index 0.
// index 0.
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
for
(
int
pipe
=
0
;
pipe
<
stages
;)
{
#pragma unroll
#pragma unroll
...
@@ -1070,16 +1145,32 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1070,16 +1145,32 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before
// For per-column scales, we only fetch them here in the final step before
// write-out
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
cp_async_fence
();
}
else
{
if
(
last
)
{
if
(
last
)
{
if
(
s_sh_wr_pred
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
_stream
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
}
cp_async_fence
();
cp_async_fence
();
}
}
}
}
}
thread_block_reduce
();
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num_bits
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
0
]
=
sh_s
[
s_sh_rd
+
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
)[
1
]
=
sh_s
[
s_sh_rd
+
4
];
}
}
else
{
if
(
last
)
{
if
(
last
)
{
cp_async_wait
<
0
>
();
cp_async_wait
<
0
>
();
__syncthreads
();
__syncthreads
();
...
@@ -1089,6 +1180,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1089,6 +1180,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}
}
}
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
if
(
slice_count
>
1
)
{
// only globally reduce if there is more than one
// block in a slice
// block in a slice
...
@@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
s_gl_rd
=
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// printf("Move\n");
// }
start_pipes
();
start_pipes
();
}
}
}
}
}
}
}
}
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,
\
#define __CALL_IF(
NUM_BITS,
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS &&
\
else if (
num_bits == NUM_BITS &&
thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
cudaFuncSetAttribute( \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,
THREAD_K_BLOCKS,
\
Marlin<
NUM_BITS,
NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,
\
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>,
\
THREAD_K_BLOCKS,
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,
THREAD_K_BLOCKS,
\
Marlin<
NUM_BITS,
NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,
\
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>
\
THREAD_K_BLOCKS,
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
prob_k, locks); \
...
@@ -1158,28 +1270,92 @@ typedef struct {
...
@@ -1158,28 +1270,92 @@ typedef struct {
int
num_threads
;
int
num_threads
;
}
thread_config_t
;
}
thread_config_t
;
thread_config_t
small_batch_thread_configs
[]
=
{
typedef
struct
{
int
max_m_blocks
;
thread_config_t
tb_cfg
;
}
exec_config_t
;
thread_config_t
thread_configs
[]
=
{
// Ordered by priority
// Ordered by priority
// thread_k, thread_n, num_threads
// thread_k, thread_n, num_threads
{
128
,
128
,
256
},
// Default
{
64
,
256
,
256
},
// Default
(max cache usage)
{
128
,
64
,
128
},
// Reduce N
2X, same K
{
64
,
128
,
128
},
// Reduce N
, reduce warps
{
64
,
256
,
256
},
// Reduce
K 2X,
increase
N 2X
{
128
,
64
,
128
},
// Reduce
N more, but
increase
K
{
64
,
128
,
128
},
// Reduce K 2X, same N
};
};
thread_config_t
large_batch_thread_configs
[]
=
{
int
get_scales_cache_size
(
thread_config_t
const
&
th_config
,
int
prob_m
,
// Ordered by priority
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
)
{
bool
cache_scales_chunk
=
has_act_order
&&
!
is_k_full
;
// thread_k, thread_n, num_threads
int
tb_n
=
th_config
.
thread_n
;
{
64
,
256
,
256
},
// Default
int
tb_k
=
th_config
.
thread_k
;
{
128
,
64
,
128
},
// Reduce N 2X, same K
{
64
,
128
,
128
},
// Reduce N 2X, same K
// Get max scale groups per thread-block
// {128, 64, 128}, // Reduce N 4X, increase K 2X
int
tb_groups
;
};
if
(
group_size
==
-
1
)
{
tb_groups
=
1
;
}
else
if
(
group_size
==
0
)
{
tb_groups
=
div_ceil
(
tb_k
,
32
);
// Worst case is 32 group size
}
else
{
tb_groups
=
div_ceil
(
tb_k
,
group_size
);
}
if
(
cache_scales_chunk
)
{
int
load_groups
=
tb_groups
*
pipe_stages
*
2
;
// Chunk size is 2x pipeline over dim K
load_groups
=
max
(
load_groups
,
32
);
// We load at least 32 scale groups
return
load_groups
*
tb_n
*
2
;
}
else
{
int
tb_scales
=
tb_groups
*
tb_n
*
2
;
return
tb_scales
*
pipe_stages
;
}
}
bool
is_valid_cache_size
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
scales_cache_size
,
int
max_shared_mem
)
{
int
pack_factor
=
32
/
num_bits
;
// Get B size
int
tb_k
=
th_config
.
thread_k
;
int
tb_n
=
th_config
.
thread_n
;
int
b_size
=
(
tb_k
*
tb_n
/
pack_factor
)
*
4
;
// Get A size
int
m_blocks
=
div_ceil
(
prob_m
,
16
);
int
tb_max_m
=
16
;
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
prob_m
,
int
prob_n
,
while
(
true
)
{
int
prob_k
)
{
if
(
m_blocks
>=
max_m_blocks
)
{
tb_max_m
*=
max_m_blocks
;
break
;
}
max_m_blocks
--
;
if
(
max_m_blocks
==
0
)
{
TORCH_CHECK
(
false
,
"Unexpected m_blocks = "
,
m_blocks
);
}
}
int
a_size
=
(
tb_max_m
*
tb_k
)
*
2
;
float
pipe_size
=
(
a_size
+
b_size
)
*
pipe_stages
;
TORCH_CHECK
(
max_shared_mem
/
2
>
scales_cache_size
);
// Sanity
return
pipe_size
<
0.95
f
*
(
max_shared_mem
-
scales_cache_size
);
}
bool
is_valid_config
(
thread_config_t
const
&
th_config
,
int
max_m_blocks
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
bool
has_act_order
,
bool
is_k_full
,
int
max_shared_mem
)
{
// Sanity
// Sanity
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
if
(
th_config
.
thread_k
==
-
1
||
th_config
.
thread_n
==
-
1
||
th_config
.
num_threads
==
-
1
)
{
th_config
.
num_threads
==
-
1
)
{
...
@@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
...
@@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
return
false
;
return
false
;
}
}
// Determine cache for scales
int
scales_cache_size
=
get_scales_cache_size
(
th_config
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
);
// Check that pipeline fits into cache
if
(
!
is_valid_cache_size
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
scales_cache_size
,
max_shared_mem
))
{
return
false
;
}
return
true
;
return
true
;
}
}
thread_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
)
{
exec_config_t
determine_thread_config
(
int
prob_m
,
int
prob_n
,
int
prob_k
,
int
num_bits
,
int
group_size
,
// TODO: Enable if needed after some more testing
bool
has_act_order
,
bool
is_k_full
,
if
(
prob_m
<=
0
)
{
int
max_shared_mem
)
{
for
(
auto
th_config
:
small_batch_thread_configs
)
{
int
max_m_blocks
=
4
;
if
(
is_valid_config
(
th_config
,
prob_m
,
prob_n
,
prob_k
))
{
while
(
max_m_blocks
>
0
)
{
return
th_config
;
for
(
auto
th_config
:
thread_configs
)
{
if
(
is_valid_config
(
th_config
,
max_m_blocks
,
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
))
{
return
exec_config_t
{
max_m_blocks
,
th_config
};
}
}
}
}
}
else
{
printf
(
"WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
for
(
auto
th_config
:
large_batch_thread_configs
)
{
"GPU cache. This may "
if
(
is_valid_config
(
th_config
,
prob_m
,
prob_n
,
prob_k
))
{
"hurt performance. Consider upgrading your GPU.
\n
"
);
return
th_config
;
}
max_m_blocks
--
;
// Process less M blocks per invocation to reduce cache
}
// usage
}
}
return
thread
_config_t
{
-
1
,
-
1
,
-
1
};
return
exec
_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}
}
;
}
}
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)
\
#define CALL_IF(
NUM_BITS,
N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
\
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
\
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
\
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
\
__CALL_IF(
NUM_BITS,
3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
\
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
void
marlin_cuda
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
g_idx
,
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
workspace
,
bool
has_act_order
,
bool
is_k_full
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
num_groups
,
int
group_size
,
int
dev
=
0
,
bool
has_act_order
,
bool
is_k_full
,
int
num_groups
,
cudaStream_t
stream
=
0
,
int
thread_k
=
-
1
,
int
thread_n
=
-
1
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
sms
=
-
1
,
int
max_par
=
16
)
{
int
thread_n
,
int
sms
,
int
max_par
)
{
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
TORCH_CHECK
(
prob_m
>
0
&&
prob_n
>
0
&&
prob_k
>
0
,
"Invalid MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
", "
,
prob_n
,
", "
,
prob_k
,
"]"
);
...
@@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
...
@@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
TORCH_CHECK
(
max_shared_mem
>
0
);
TORCH_CHECK
(
max_shared_mem
>
0
);
// Set thread config
// Set thread config
thread
_config_t
th_confi
g
;
exec
_config_t
exec_cf
g
;
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
if
(
thread_k
!=
-
1
&&
thread_n
!=
-
1
)
{
// User-defined config
// User-defined config
th_config
=
thread_config_t
{
thread_k
,
thread_n
,
default_threads
};
exec_cfg
=
exec_config_t
{
4
,
thread_config_t
{
thread_k
,
thread_n
,
default_threads
}};
}
else
{
}
else
{
// Auto config
// Auto config
th_config
=
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
);
exec_cfg
=
}
determine_thread_config
(
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
has_act_order
,
is_k_full
,
max_shared_mem
);
TORCH_CHECK
(
is_valid_config
(
th_config
,
prob_m
,
prob_n
,
prob_k
),
}
"Invalid thread config: thread_k = "
+
str
(
th_config
.
thread_k
)
+
", thread_n = "
+
str
(
th_config
.
thread_n
)
+
TORCH_CHECK
(
exec_cfg
.
max_m_blocks
>
0
&&
", num_threads = "
+
str
(
th_config
.
num_threads
)
+
is_valid_config
(
exec_cfg
.
tb_cfg
,
exec_cfg
.
max_m_blocks
,
" for MKN = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_k
)
+
", "
+
prob_m
,
prob_n
,
prob_k
,
num_bits
,
group_size
,
str
(
prob_n
)
+
"]"
);
has_act_order
,
is_k_full
,
max_shared_mem
),
"Invalid thread config: max_m_blocks = "
,
exec_cfg
.
max_m_blocks
,
int
num_threads
=
th_config
.
num_threads
;
", thread_k = "
,
exec_cfg
.
tb_cfg
.
thread_k
,
thread_k
=
th_config
.
thread_k
;
", thread_n = "
,
exec_cfg
.
tb_cfg
.
thread_n
,
thread_n
=
th_config
.
thread_n
;
", num_threads = "
,
exec_cfg
.
tb_cfg
.
num_threads
,
" for MKN = ["
,
prob_m
,
", "
,
prob_k
,
", "
,
prob_n
,
"] and num_bits = "
,
num_bits
,
", group_size = "
,
group_size
,
", has_act_order = "
,
has_act_order
,
", is_k_full = "
,
is_k_full
,
", max_shared_mem = "
,
max_shared_mem
);
int
num_threads
=
exec_cfg
.
tb_cfg
.
num_threads
;
thread_k
=
exec_cfg
.
tb_cfg
.
thread_k
;
thread_n
=
exec_cfg
.
tb_cfg
.
thread_n
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_k_blocks
=
thread_k
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
int
thread_n_blocks
=
thread_n
/
16
;
...
@@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
...
@@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
}
}
// Main loop
// Main loop
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
4
)
{
for
(
int
i
=
0
;
i
<
tot_m_blocks
;
i
+=
exec_cfg
.
max_m_blocks
)
{
int
thread_m_blocks
=
tot_m_blocks
-
i
;
int
thread_m_blocks
=
tot_m_blocks
-
i
;
prob_m
=
tot_m
-
16
*
i
;
prob_m
=
tot_m
-
16
*
i
;
int
par
=
1
;
int
par
=
1
;
if
(
thread_m_blocks
>
4
)
{
if
(
thread_m_blocks
>
exec_cfg
.
max_m_blocks
)
{
// Note that parallel > 1 currently only works for inputs without any
// Note that parallel > 1 currently only works for inputs without any
// padding
// padding
par
=
(
16
*
thread_m_blocks
-
pad
)
/
64
;
par
=
(
16
*
thread_m_blocks
-
pad
)
/
(
16
*
exec_cfg
.
max_m_blocks
)
;
if
(
par
>
max_par
)
if
(
par
>
max_par
)
par
=
max_par
;
par
=
max_par
;
prob_m
=
64
*
par
;
prob_m
=
(
16
*
exec_cfg
.
max_m_blocks
)
*
par
;
i
+=
4
*
(
par
-
1
);
i
+=
exec_cfg
.
max_m_blocks
*
(
par
-
1
);
thread_m_blocks
=
4
;
thread_m_blocks
=
exec_cfg
.
max_m_blocks
;
}
}
// Define kernel configurations
// Define kernel configurations
if
(
false
)
{
if
(
false
)
{
}
}
CALL_IF
(
16
,
4
,
256
)
CALL_IF
(
4
,
32
,
2
,
256
)
CALL_IF
(
8
,
8
,
256
)
CALL_IF
(
4
,
16
,
4
,
256
)
CALL_IF
(
8
,
4
,
128
)
CALL_IF
(
4
,
8
,
4
,
128
)
CALL_IF
(
4
,
8
,
128
)
CALL_IF
(
4
,
4
,
8
,
128
)
CALL_IF
(
8
,
32
,
2
,
256
)
CALL_IF
(
8
,
16
,
4
,
256
)
CALL_IF
(
8
,
8
,
4
,
128
)
CALL_IF
(
8
,
4
,
8
,
128
)
else
{
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
+
str
(
prob_m
)
+
", "
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
str
(
prob_n
)
+
", "
+
str
(
prob_k
)
+
"]"
+
...
@@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
...
@@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
bool
is_k_full
)
{
int64_t
size_k
,
bool
is_k_full
)
{
// Verify num_bits
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
pack_factor
=
32
/
num_bits
;
// Verify A
// Verify A
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
TORCH_CHECK
(
a
.
size
(
0
)
==
size_m
,
"Shape mismatch: a.size(0) = "
,
a
.
size
(
0
),
"Shape mismatch: a.size(0) = "
+
str
(
a
.
size
(
0
))
+
", size_m = "
,
size_m
);
", size_m = "
+
str
(
size_m
));
TORCH_CHECK
(
a
.
size
(
1
)
==
size_k
,
"Shape mismatch: a.size(1) = "
,
a
.
size
(
1
),
TORCH_CHECK
(
a
.
size
(
1
)
==
size_k
,
", size_k = "
,
size_k
);
"Shape mismatch: a.size(1) = "
+
str
(
a
.
size
(
1
))
+
", size_k = "
+
str
(
size_k
));
// Verify B
// Verify B
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_size
==
0
,
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_size
==
0
,
"size_k = "
,
size_k
,
"size_k = "
+
str
(
size_k
)
+
" is not divisible by tile_size = "
+
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
str
(
gptq_marlin
::
tile_size
));
TORCH_CHECK
((
size_k
/
gptq_marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
gptq_marlin
::
tile_size
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
+
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
str
(
b_q_weight
.
size
(
0
))
+
", size_k = "
+
str
(
size_k
)
+
", size_k = "
,
size_k
,
", tile_size = "
,
gptq_marlin
::
tile_size
);
", tile_size = "
+
str
(
gptq_marlin
::
tile_size
));
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
%
gptq_marlin
::
tile_size
==
0
,
TORCH_CHECK
(
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
b_q_weight
.
size
(
1
)
%
gptq_marlin
::
tile_size
==
0
,
" is not divisible by tile_size = "
,
gptq_marlin
::
tile_size
);
"b_q_weight.size(1) = "
+
str
(
b_q_weight
.
size
(
1
))
+
int
actual_size_n
=
" is not divisible by tile_size = "
+
str
(
gptq_marlin
::
tile_size
));
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
pack_factor
;
int
actual_size_n
=
(
b_q_weight
.
size
(
1
)
/
gptq_marlin
::
tile_size
)
*
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
,
size_n
,
gptq_marlin
::
pack_factor_4bit
;
", actual_size_n = "
,
actual_size_n
);
TORCH_CHECK
(
size_n
==
actual_size_n
,
"size_n = "
+
str
(
size_n
)
+
", actual_size_n = "
+
str
(
actual_size_n
));
// Verify device and strides
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
...
@@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
// Verify g_idx and perm
// Verify g_idx and perm
TORCH_CHECK
((
g_idx
.
size
(
0
)
==
0
&&
perm
.
size
(
0
)
==
0
)
||
TORCH_CHECK
((
g_idx
.
size
(
0
)
==
0
&&
perm
.
size
(
0
)
==
0
)
||
(
g_idx
.
size
(
0
)
==
size_k
&&
perm
.
size
(
0
)
==
size_k
),
(
g_idx
.
size
(
0
)
==
size_k
&&
perm
.
size
(
0
)
==
size_k
),
"Unexpected g_idx.size(0) = "
+
str
(
g_idx
.
size
(
0
)
)
+
"Unexpected g_idx.size(0) = "
,
g_idx
.
size
(
0
)
,
" and perm.size(0) = "
+
str
(
perm
.
size
(
0
)
)
+
" and perm.size(0) = "
,
perm
.
size
(
0
)
,
", where size_k = "
+
str
(
size_k
)
)
;
", where size_k = "
,
size_k
);
// Detect groupsize and act_order
// Detect groupsize and act_order
int
num_groups
=
-
1
;
int
num_groups
=
-
1
;
...
@@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
if
(
has_act_order
)
{
if
(
has_act_order
)
{
if
(
is_k_full
)
{
if
(
is_k_full
)
{
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
num_groups
>
1
,
"For act_order, num_groups must be > 1"
);
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
"size_k = "
+
str
(
size_k
)
+
", is not divisible by num_groups = "
,
num_groups
);
", is not divisible by num_groups = "
+
str
(
num_groups
));
group_size
=
size_k
/
num_groups
;
group_size
=
size_k
/
num_groups
;
}
else
{
}
else
{
group_size
=
0
;
group_size
=
0
;
...
@@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
}
else
{
}
else
{
if
(
num_groups
>
1
)
{
if
(
num_groups
>
1
)
{
TORCH_CHECK
(
size_k
%
num_groups
==
0
,
TORCH_CHECK
(
"size_k = "
+
str
(
size_k
)
+
size_k
%
num_groups
==
0
,
"size_k = "
,
size_k
,
", is not divisible by b_scales.size(0) = "
+
", is not divisible by b_scales.size(0) = "
,
b_scales
.
size
(
0
));
str
(
b_scales
.
size
(
0
)));
group_size
=
size_k
/
num_groups
;
group_size
=
size_k
/
num_groups
;
}
else
{
}
else
{
group_size
=
-
1
;
group_size
=
-
1
;
...
@@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
}
}
// Verify workspace size
// Verify workspace size
TORCH_CHECK
(
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
TORCH_CHECK
(
"size_n = "
+
str
(
size_n
)
+
size_n
%
gptq_marlin
::
min_thread_n
==
0
,
"size_n = "
,
size_n
,
", is not divisible by min_thread_n = "
+
", is not divisible by min_thread_n = "
,
gptq_marlin
::
min_thread_n
);
str
(
gptq_marlin
::
min_thread_n
));
int
min_workspace_size
=
int
min_workspace_size
=
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
(
size_n
/
gptq_marlin
::
min_thread_n
)
*
gptq_marlin
::
max_par
;
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
TORCH_CHECK
(
workspace
.
numel
()
>=
min_workspace_size
,
"workspace.numel = "
+
str
(
workspace
.
numel
()
)
+
"workspace.numel = "
,
workspace
.
numel
()
,
" is below min_workspace_size = "
+
str
(
min_workspace_size
)
)
;
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
gptq_marlin
::
marlin_
cuda
(
gptq_marlin
::
marlin_
mm_f16i4
(
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
size_m
,
size_n
,
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
has_act_order
,
is_k_full
,
num_groups
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
sms
,
gptq_marlin
::
max_par
);
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
return
c
;
return
c
;
}
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cuh
View file @
7038e8b8
...
@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
...
@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
pack_factor_4bit
=
8
;
// We have 8 4-bit vals inside a 32 bit
template
<
typename
T
,
int
n
>
template
<
typename
T
,
int
n
>
struct
Vec
{
struct
Vec
{
T
elems
[
n
];
T
elems
[
n
];
...
@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
...
@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
}
__device__
inline
void
cp_async4
_stream
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
__device__
inline
void
cp_async4
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
const
int
BYTES
=
16
;
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
asm
volatile
(
"{
\n
"
" .reg .b64 p;
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;
\n
"
"}
\n
"
::
"r"
(
smem
),
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
7038e8b8
...
@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
...
@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
int
const
num_threads
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
...
@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
// namespace gptq_marlin
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
)
{
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
return
torch
::
empty
({
1
,
1
});
...
@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
...
@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else
#else
template
<
int
const
num_threads
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
constexpr
int
pack_factor
=
32
/
num_bits
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
...
@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr
+=
perm_size
;
sh_pipe_ptr
+=
perm_size
;
}
}
constexpr
int
tile_ints
=
tile_k_size
/
pack_factor
;
constexpr
int
stage_n_threads
=
tile_n_size
/
4
;
constexpr
int
stage_n_threads
=
tile_n_size
/
4
;
constexpr
int
stage_k_threads
=
constexpr
int
stage_k_threads
=
has_perm
?
tile_k_size
:
tile_ints
;
has_perm
?
tile_k_size
:
tile_k_size
/
pack_factor_4bit
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
auto
load_perm_to_shared
=
[
&
](
int
k_tile_id
)
{
auto
load_perm_to_shared
=
[
&
](
int
k_tile_id
)
{
...
@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
int
src_k
=
sh_perm_int_ptr
[
k_id
];
int
src_k
=
sh_perm_int_ptr
[
k_id
];
int
src_k_packed
=
src_k
/
pack_factor
_4bit
;
int
src_k_packed
=
src_k
/
pack_factor
;
cp_async4
_stream
(
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[
src_k_packed
*
size_n
+
first_n
+
(
n_id
*
4
)])));
b_q_weight_ptr
[
src_k_packed
*
size_n
+
first_n
+
(
n_id
*
4
)])));
...
@@ -113,9 +117,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -113,9 +117,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k_packed
=
first_k
/
pack_factor
_4bit
;
int
first_k_packed
=
first_k
/
pack_factor
;
cp_async4
_stream
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
first_n
+
(
n_id
*
4
)])));
first_n
+
(
n_id
*
4
)])));
...
@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int
cur_n
=
warp_id
*
16
+
tc_col
;
int
cur_n
=
warp_id
*
16
+
tc_col
;
constexpr
int
sh_stride
=
64
;
constexpr
int
sh_stride
=
64
;
constexpr
uint32_t
mask
=
(
1
<<
num_bits
)
-
1
;
int4
*
sh_stage_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
int4
*
sh_stage_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
uint32_t
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_perm_ptr
);
uint32_t
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_perm_ptr
);
uint32_t
vals
[
pack_factor_4bit
];
uint32_t
vals
[
8
];
if
constexpr
(
has_perm
)
{
if
constexpr
(
has_perm
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
k_idx
=
tc_row
+
tc_offsets
[
i
];
int
k_idx
=
tc_row
+
tc_offsets
[
i
];
uint32_t
src_k
=
sh_perm_int_ptr
[
k_idx
];
uint32_t
src_k
=
sh_perm_int_ptr
[
k_idx
];
uint32_t
src_k_pos
=
src_k
%
pack_factor
_4bit
;
uint32_t
src_k_pos
=
src_k
%
pack_factor
;
uint32_t
b1_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
];
uint32_t
b1_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
];
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
4
))
&
0xf
;
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
uint32_t
b2_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
+
8
];
uint32_t
b2_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
+
8
];
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
4
))
&
0xf
;
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
vals
[
i
]
=
b1_cur_val
;
vals
[
i
]
=
b1_cur_val
;
vals
[
4
+
i
]
=
b2_cur_val
;
vals
[
4
+
i
]
=
b2_cur_val
;
...
@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
else
{
}
else
{
uint32_t
b1_val_1
=
sh_stage_int_ptr
[
cur_n
];
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b1_val_2
=
sh_stage_int_ptr
[
sh_stride
+
cur_n
];
uint32_t
b2_vals
[
tile_ints
];
uint32_t
b2_val_1
=
sh_stage_int_ptr
[
cur_n
+
8
];
uint32_t
b2_val_2
=
sh_stage_int_ptr
[
sh_stride
+
cur_n
+
8
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
vals
[
i
]
=
(
b1_val_1
>>
(
cur_elem
*
4
))
&
0xf
;
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
vals
[
4
+
i
]
=
(
b2_val_1
>>
(
cur_elem
*
4
))
&
0xf
;
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
2
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
]
-
8
;
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
vals
[
i
]
=
(
b1_val_2
>>
(
cur_elem
*
4
))
&
0xf
;
int
cur_int
=
cur_elem
/
pack_factor
;
vals
[
4
+
i
]
=
(
b2_val_2
>>
(
cur_elem
*
4
))
&
0xf
;
int
cur_pos
=
cur_elem
%
pack_factor
;
vals
[
i
]
=
(
b1_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
vals
[
4
+
i
]
=
(
b2_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
}
}
}
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
// Result of:
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr
int
pack_idx
[
pack_factor_4bit
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
uint32_t
res
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_factor_4bit
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor_4bit
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
}
else
{
constexpr
int
pack_idx
[
4
]
=
{
0
,
2
,
1
,
3
};
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
}
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
0
]
=
res1
;
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
1
]
=
res2
;
}
};
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
...
@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
...
@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
// namespace gptq_marlin
}
// namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
)
{
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_marlin
::
tile_k_size
);
" is not divisible by tile_k_size = "
,
gptq_marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
TORCH_CHECK
(
size_n
%
gptq_marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_marlin
::
tile_n_size
);
" is not divisible by tile_n_size = "
,
gptq_marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
const
pack_factor
=
32
/
num_bits
;
// Verify B
// Verify B
TORCH_CHECK
((
size_k
/
gptq_marlin
::
pack_factor
_4bit
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
pack_factor
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", size_k = "
,
size_k
,
", pack_factor = "
,
pack_factor
);
", pack_factor_4bit = "
,
gptq_marlin
::
pack_factor_4bit
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
==
size_n
,
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
==
size_n
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not size_n = "
,
size_n
);
" is not size_n = "
,
size_n
);
...
@@ -273,9 +309,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
...
@@ -273,9 +309,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto
options
=
torch
::
TensorOptions
()
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
(
torch
::
Tensor
out
=
{
size_k
/
gptq_marlin
::
tile_size
,
torch
::
empty
(
{
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
gptq_marlin
::
pack_factor
_4bit
},
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
options
);
// Detect if there is act_order
// Detect if there is act_order
...
@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
...
@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
has_perm
)
{
if
(
false
)
{
cudaFuncSetAttribute
(
}
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
true
>
,
CALL_IF
(
4
,
false
)
cudaFuncAttributeMaxDynamicSharedMemorySize
,
CALL_IF
(
4
,
true
)
max_shared_mem
);
CALL_IF
(
8
,
false
)
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
true
>
CALL_IF
(
8
,
true
)
<<<
blocks
,
gptq_marlin
::
repack_threads
,
max_shared_mem
,
else
{
stream
>>>
(
b_q_weight_ptr
,
perm_ptr
,
out_ptr
,
size_k
,
size_n
);
TORCH_CHECK
(
false
,
"Unsupported repack config: num_bits = "
,
num_bits
,
", has_perm = "
,
has_perm
);
}
else
{
cudaFuncSetAttribute
(
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
false
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
max_shared_mem
);
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
false
>
<<<
blocks
,
gptq_marlin
::
repack_threads
,
max_shared_mem
,
stream
>>>
(
b_q_weight_ptr
,
perm_ptr
,
out_ptr
,
size_k
,
size_n
);
}
}
return
out
;
return
out
;
...
...
tests/models/test_gptq_marlin.py
View file @
7038e8b8
...
@@ -39,6 +39,13 @@ MODELS = [
...
@@ -39,6 +39,13 @@ MODELS = [
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-64g-actorder_True"
),
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-64g-actorder_True"
),
# act_order==True, group_size=32
# act_order==True, group_size=32
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-32g-actorder_True"
),
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-32g-actorder_True"
),
# 8-bit, act_order==True, group_size=channelwise
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit--1g-actorder_True"
),
# 8-bit, act_order==True, group_size=128
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit-128g-actorder_True"
),
# 8-bit, act_order==True, group_size=32
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit-32g-actorder_True"
),
]
]
...
@@ -65,8 +72,7 @@ def test_models(
...
@@ -65,8 +72,7 @@ def test_models(
dtype
=
dtype
,
dtype
=
dtype
,
quantization
=
"marlin"
,
quantization
=
"marlin"
,
max_model_len
=
MAX_MODEL_LEN
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
)
disable_custom_all_reduce
=
True
)
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
...
@@ -78,8 +84,7 @@ def test_models(
...
@@ -78,8 +84,7 @@ def test_models(
dtype
=
dtype
,
dtype
=
dtype
,
quantization
=
"gptq"
,
quantization
=
"gptq"
,
max_model_len
=
MAX_MODEL_LEN
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
)
disable_custom_all_reduce
=
True
)
gptq_outputs
=
gptq_model
.
generate_greedy_logprobs
(
example_prompts
,
gptq_outputs
=
gptq_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
max_tokens
,
num_logprobs
)
num_logprobs
)
...
...
vllm/_custom_ops.py
View file @
7038e8b8
...
@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
...
@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin
# gptq_marlin
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
)
->
torch
.
Tensor
:
size_k
:
int
,
size_n
:
int
,
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
)
num_bits
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_k
:
int
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
is_k_full
:
bool
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
size_m
,
size_n
,
size_k
,
workspace
,
num_bits
,
size_m
,
size_n
,
is_k_full
)
size_k
,
is_k_full
)
# fp8
# fp8
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
7038e8b8
...
@@ -2,7 +2,6 @@ import enum
...
@@ -2,7 +2,6 @@ import enum
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
numpy
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
...
@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
...
@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
]
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Precompute permutations for Marlin weight and scale shuffling
# Permutations for Marlin scale shuffling
#
def
get_scale_perms
(
num_bits
):
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def
_get_perms
():
perm
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm
)
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
perm
=
perm
.
reshape
((
-
1
,
8
))[:,
interleave
].
ravel
()
# type: ignore
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
scale_perm
=
[]
for
i
in
range
(
8
):
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
...
@@ -59,23 +30,21 @@ def _get_perms():
...
@@ -59,23 +30,21 @@ def _get_perms():
for
i
in
range
(
4
):
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
return
scale_perm
,
scale_perm_single
_perm
,
_scale_perm
,
_scale_perm_single
=
_get_perms
()
def
get_pack_factor
(
num_bits
):
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
(
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
f
"Unsupported num_bits =
{
num_bits
}
"
)
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
return
32
//
num_bits
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
):
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
_
scale_perm
)))[:,
_
scale_perm
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
else
:
s
=
s
.
reshape
((
-
1
,
len
(
_
scale_perm_single
)))[:,
_
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
return
s
...
@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
requires_grad
=
False
,
)
)
set_weight_attrs
(
set_weight_attrs
(
qweight
,
{
qweight
,
{
**
extra_weight_attrs
,
**
extra_weight_attrs
,
"input_dim"
:
0
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
},
)
# Activation order
# Activation order
g_idx
=
Parameter
(
g_idx
=
Parameter
(
...
@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
requires_grad
=
False
,
)
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
"ignore_warning"
:
True
})
},
)
g_idx_sort_indices
=
Parameter
(
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
torch
.
empty
(
...
@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
requires_grad
=
False
,
)
)
set_weight_attrs
(
set_weight_attrs
(
scales
,
{
scales
,
{
**
extra_weight_attrs
,
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"output_dim"
:
1
,
})
},
)
# Quantized zero-points
# Quantized zero-points
qzeros
=
Parameter
(
qzeros
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
torch
.
empty
(
output_size_per_partition
//
scales_and_zp_size
,
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
"meta"
),
device
=
"meta"
,
),
requires_grad
=
False
,
requires_grad
=
False
,
)
)
set_weight_attrs
(
set_weight_attrs
(
qzeros
,
{
qzeros
,
{
**
extra_weight_attrs
,
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
},
)
# Allocate marlin workspace
# Allocate marlin workspace
max_workspace_size
=
(
max_workspace_size
=
(
...
@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else
:
else
:
# Reset g_idx related tensors
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
layer
.
g_idx
=
Parameter
(
dtype
=
torch
.
int
,
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
device
=
cur_device
),
requires_grad
=
False
,
requires_grad
=
False
)
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
layer
.
g_idx_sort_indices
=
Parameter
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
)
requires_grad
=
False
,
)
# Repack weights
# Repack weights
marlin_qweight
=
ops
.
gptq_marlin_repack
(
marlin_qweight
=
ops
.
gptq_marlin_repack
(
...
@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
g_idx_sort_indices
,
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_k
,
part_size_n
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
...
@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if
self
.
quant_config
.
desc_act
:
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
scales_size_k
=
full_size_k
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
scales_size_n
,
self
.
quant_config
.
group_size
)
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
replace_tensor
(
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
output
=
ops
.
gptq_marlin_gemm
(
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
reshaped_x
,
layer
.
workspace
,
size_m
,
part_size_n
,
layer
.
qweight
,
part_size_k
,
layer
.
is_k_full
)
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment