Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b3cc75b
Unverified
Commit
6b3cc75b
authored
Mar 24, 2025
by
Jinzhen Lin
Committed by
GitHub
Mar 24, 2025
Browse files
[Kernel] allow non-contiguous input for marlin kernel (#14658)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
7ffcccfa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
20 deletions
+73
-20
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+28
-20
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+45
-0
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
6b3cc75b
...
...
@@ -42,7 +42,7 @@ namespace marlin {
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{}
int
size_k
,
int
lda
,
int
block_rows
)
{}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
...
...
@@ -459,7 +459,7 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
__global__
void
permute_cols_kernel
(
int4
const
*
__restrict__
a_int4_ptr
,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{
int
size_k
,
int
lda
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
...
...
@@ -467,16 +467,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}
int
cur_block_rows
=
finish_row
-
start_row
;
int
row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
int
input_row_stride
=
lda
*
sizeof
(
half
)
/
16
;
int
output_row_stride
=
size_k
*
sizeof
(
half
)
/
16
;
auto
permute_row
=
[
&
](
int
row
)
{
int
iters
=
size_k
/
default_threads
;
int
rest
=
size_k
%
default_threads
;
int
offset
=
row
*
row_stride
;
int
input_offset
=
row
*
input_row_stride
;
int
output_offset
=
row
*
output_row_stride
;
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
offset
);
half
const
*
a_row_half
=
reinterpret_cast
<
half
const
*>
(
a_int4_ptr
+
input_offset
);
half
*
out_half
=
reinterpret_cast
<
half
*>
(
out_int4_ptr
+
output_offset
);
int
base_k
=
0
;
...
...
@@ -537,6 +540,7 @@ __global__ void Marlin(
int
prob_m
,
// batch dimension m
int
prob_n
,
// output dimension n
int
prob_k
,
// reduction dimension k
int
lda
,
// A.stride(0), equal to prob_k is A is contiguous
int
*
locks
,
// extra global storage for barrier synchronization
bool
use_atomic_add
,
// whether to use atomic add to reduce
bool
use_fp32_reduce
// whether to use fp32 global reduce
...
...
@@ -600,7 +604,7 @@ __global__ void Marlin(
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if
(
slice_col_par
>=
n_tiles
)
{
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_k
/
8
;
A
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
(
slice_col_par
/
n_tiles
)
*
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
(
slice_col_par
/
n_tiles
)
*
n_tiles
;
slice_col
=
slice_col_par
%
n_tiles
;
...
...
@@ -631,7 +635,7 @@ __global__ void Marlin(
}
}
if
(
slice_col
==
n_tiles
)
{
A
+=
16
*
thread_m_blocks
*
prob_k
/
8
;
A
+=
16
*
thread_m_blocks
*
lda
/
8
;
C
+=
16
*
thread_m_blocks
*
prob_n
/
8
;
locks
+=
n_tiles
;
slice_col
=
0
;
...
...
@@ -643,7 +647,7 @@ __global__ void Marlin(
// A sizes/strides
// stride of the A matrix in global memory
int
a_gl_stride
=
prob_k
/
8
;
int
a_gl_stride
=
lda
/
8
;
// stride of an A matrix tile in shared memory
constexpr
int
a_sh_stride
=
16
*
thread_k_blocks
/
8
;
// delta between subsequent A tiles in global memory
...
...
@@ -1780,8 +1784,8 @@ __global__ void Marlin(
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks,
use_atomic_add,
\
use_
fp32_reduce);
\
num_groups, prob_m, prob_n, prob_k,
lda,
locks,
\
use_
atomic_add, use_fp32_reduce);
\
} \
}
...
...
@@ -2071,7 +2075,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
void
*
zp
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
prob_n
,
int
prob_k
,
int
lda
,
void
*
workspace
,
vllm
::
ScalarType
const
&
q_type
,
bool
has_act_order
,
bool
is_k_full
,
bool
has_zp
,
int
num_groups
,
int
group_size
,
int
dev
,
cudaStream_t
stream
,
int
thread_k
,
int
thread_n
,
...
...
@@ -2184,8 +2188,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
// Permute A columns
int
block_rows
=
div_ceil
(
prob_m
,
blocks
);
permute_cols_kernel
<<<
blocks
,
default_threads
,
0
,
stream
>>>
(
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
prob_m
,
prob_k
,
block_rows
);
A_ptr
,
perm_ptr
,
a_tmp_ptr
,
prob_m
,
prob_k
,
lda
,
block_rows
);
A_ptr
=
a_tmp_ptr
;
lda
=
prob_k
;
}
// If we have a full K, then we can run the non-act-order version of Marlin
...
...
@@ -2244,7 +2249,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
", num_bits = "
,
num_bits
);
}
A_ptr
+=
16
*
thread_m_blocks
*
(
prob_k
/
8
)
*
par
;
A_ptr
+=
16
*
thread_m_blocks
*
(
lda
/
8
)
*
par
;
C_ptr
+=
16
*
thread_m_blocks
*
(
prob_n
/
8
)
*
par
;
}
}
...
...
@@ -2300,7 +2305,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
// Verify device and strides
TORCH_CHECK
(
a
.
device
().
is_cuda
(),
"A is not on GPU"
);
TORCH_CHECK
(
a
.
is_contiguous
(),
"A is not contiguous"
);
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
,
"A.stride(1) is not 1"
);
// We use int4 (16 bytes) to load A, so A must aligned to 16 bytes
TORCH_CHECK
(
a
.
stride
(
0
)
%
8
==
0
,
"A.stride(0) must divisible by 8"
);
TORCH_CHECK
(((
uint64_t
)
a
.
data_ptr
())
%
16
==
0
,
"A must aligned to 16 bytes"
);
TORCH_CHECK
(
b_q_weight
.
device
().
is_cuda
(),
"b_q_weight is not on GPU"
);
TORCH_CHECK
(
b_q_weight
.
is_contiguous
(),
"b_q_weight is not contiguous"
);
...
...
@@ -2432,7 +2440,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
size_k
,
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
...
...
@@ -2443,10 +2451,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
c
.
data_ptr
<
at
::
BFloat16
>
(),
c_tmp
.
data_ptr
<
float
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
b_zeros
.
data_ptr
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
a
.
stride
(
0
),
workspace
.
data_ptr
(),
b_q_type
,
has_act_order
,
is_k_full
,
has_zp
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
marlin
::
max_par
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
...
...
tests/kernels/test_marlin_gemm.py
View file @
6b3cc75b
...
...
@@ -606,6 +606,51 @@ def test_marlin_qqq_gemm(
assert
max_diff
<
0.04
def
test_marlin_gemm_subset_input
():
quant_type
=
scalar_types
.
uint4b8
group_size
=
128
size_m
,
size_k
,
size_n
=
32
,
1024
,
2048
big_m
=
size_m
*
2
big_k
=
size_k
*
2
a_input
=
rand_data
((
big_m
,
big_k
))[
8
:
size_m
+
8
,
8
:
size_k
+
8
]
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
False
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
has_zp
=
False
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
def
test_marlin_gemm_opcheck
():
size_m
=
2048
size_n
=
4096
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment