Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7038e8b8
Unverified
Commit
7038e8b8
authored
May 02, 2024
by
alexm-nm
Committed by
GitHub
May 02, 2024
Browse files
[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)
parent
2a85f930
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
553 additions
and
324 deletions
+553
-324
csrc/ops.h
csrc/ops.h
+3
-1
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+377
-175
csrc/quantization/gptq_marlin/gptq_marlin.cuh
csrc/quantization/gptq_marlin/gptq_marlin.cuh
+2
-6
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+90
-62
tests/models/test_gptq_marlin.py
tests/models/test_gptq_marlin.py
+9
-4
vllm/_custom_ops.py
vllm/_custom_ops.py
+8
-6
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+64
-70
No files found.
csrc/ops.h
View file @
7038e8b8
...
...
@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
...
...
@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
);
int64_t
size_n
,
int64_t
num_bits
);
#endif
void
squeezellm_gemm
(
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
7038e8b8
This diff is collapsed.
Click to expand it.
csrc/quantization/gptq_marlin/gptq_marlin.cuh
View file @
7038e8b8
...
...
@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static
constexpr
int
tile_size
=
16
;
static
constexpr
int
max_par
=
16
;
static
constexpr
int
pack_factor_4bit
=
8
;
// We have 8 4-bit vals inside a 32 bit
template
<
typename
T
,
int
n
>
struct
Vec
{
T
elems
[
n
];
...
...
@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
__device__
inline
void
cp_async4
_stream
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
__device__
inline
void
cp_async4
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" .reg .b64 p;
\n
"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
7038e8b8
...
...
@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template
<
int
const
num_threads
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
...
...
@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
// namespace gptq_marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
)
{
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
...
...
@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else
template
<
int
const
num_threads
,
bool
const
has_perm
>
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{
constexpr
int
pack_factor
=
32
/
num_bits
;
int
k_tiles
=
size_k
/
tile_k_size
;
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
...
...
@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr
+=
perm_size
;
}
constexpr
int
tile_ints
=
tile_k_size
/
pack_factor
;
constexpr
int
stage_n_threads
=
tile_n_size
/
4
;
constexpr
int
stage_k_threads
=
has_perm
?
tile_k_size
:
tile_k_size
/
pack_factor_4bit
;
constexpr
int
stage_k_threads
=
has_perm
?
tile_k_size
:
tile_ints
;
constexpr
int
stage_size
=
stage_k_threads
*
stage_n_threads
;
auto
load_perm_to_shared
=
[
&
](
int
k_tile_id
)
{
...
...
@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
int
src_k
=
sh_perm_int_ptr
[
k_id
];
int
src_k_packed
=
src_k
/
pack_factor
_4bit
;
int
src_k_packed
=
src_k
/
pack_factor
;
cp_async4
_stream
(
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[
src_k_packed
*
size_n
+
first_n
+
(
n_id
*
4
)])));
...
...
@@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k_packed
=
first_k
/
pack_factor
_4bit
;
int
first_k_packed
=
first_k
/
pack_factor
;
cp_async4
_stream
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
first_n
+
(
n_id
*
4
)])));
cp_async4
(
&
sh_ptr
[
k_id
*
stage_n_threads
+
n_id
],
reinterpret_cast
<
int4
const
*>
(
&
(
b_q_weight_ptr
[(
first_k_packed
+
k_id
)
*
size_n
+
first_n
+
(
n_id
*
4
)])));
}
}
...
...
@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int
cur_n
=
warp_id
*
16
+
tc_col
;
constexpr
int
sh_stride
=
64
;
constexpr
uint32_t
mask
=
(
1
<<
num_bits
)
-
1
;
int4
*
sh_stage_ptr
=
sh_pipe_ptr
+
stage_size
*
pipe
;
uint32_t
*
sh_stage_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_stage_ptr
);
uint32_t
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
*>
(
sh_perm_ptr
);
uint32_t
vals
[
pack_factor_4bit
];
uint32_t
vals
[
8
];
if
constexpr
(
has_perm
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
k_idx
=
tc_row
+
tc_offsets
[
i
];
uint32_t
src_k
=
sh_perm_int_ptr
[
k_idx
];
uint32_t
src_k_pos
=
src_k
%
pack_factor
_4bit
;
uint32_t
src_k_pos
=
src_k
%
pack_factor
;
uint32_t
b1_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
];
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
4
))
&
0xf
;
uint32_t
b1_cur_val
=
(
b1_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
uint32_t
b2_val
=
sh_stage_int_ptr
[
k_idx
*
sh_stride
+
cur_n
+
8
];
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
4
))
&
0xf
;
uint32_t
b2_cur_val
=
(
b2_val
>>
(
src_k_pos
*
num_bits
))
&
mask
;
vals
[
i
]
=
b1_cur_val
;
vals
[
4
+
i
]
=
b2_cur_val
;
...
...
@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
else
{
uint32_t
b1_val_1
=
sh_stage_int_ptr
[
cur_n
];
uint32_t
b1_val_2
=
sh_stage_int_ptr
[
sh_stride
+
cur_n
];
uint32_t
b2_val_1
=
sh_stage_int_ptr
[
cur_n
+
8
];
uint32_t
b2_val_2
=
sh_stage_int_ptr
[
sh_stride
+
cur_n
+
8
];
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b2_vals
[
tile_ints
];
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
vals
[
i
]
=
(
b1_val_1
>>
(
cur_elem
*
4
))
&
0xf
;
vals
[
4
+
i
]
=
(
b2_val_1
>>
(
cur_elem
*
4
))
&
0xf
;
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
}
#pragma unroll
for
(
int
i
=
2
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
]
-
8
;
vals
[
i
]
=
(
b1_val_2
>>
(
cur_elem
*
4
))
&
0xf
;
vals
[
4
+
i
]
=
(
b2_val_2
>>
(
cur_elem
*
4
))
&
0xf
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
cur_int
=
cur_elem
/
pack_factor
;
int
cur_pos
=
cur_elem
%
pack_factor
;
vals
[
i
]
=
(
b1_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
vals
[
4
+
i
]
=
(
b2_vals
[
cur_int
]
>>
(
cur_pos
*
num_bits
))
&
mask
;
}
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr
int
pack_idx
[
pack_factor_4bit
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
if
constexpr
(
num_bits
==
4
)
{
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
uint32_t
res
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
pack_factor_4bit
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
constexpr
int
tile_size
=
tile_k_size
*
tile_n_size
/
pack_factor_4bit
;
int
out_offset
=
(
k_tile_id
*
n_tiles
+
n_tile_id
)
*
tile_size
;
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
out_ptr
[
out_offset
+
th_id
*
4
+
warp_id
]
=
res
;
}
else
{
constexpr
int
pack_idx
[
4
]
=
{
0
,
2
,
1
,
3
};
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
}
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
0
]
=
res1
;
out_ptr
[
out_offset
+
th_id
*
8
+
(
warp_id
*
2
)
+
1
]
=
res2
;
}
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
...
...
@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
}
// namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
)
{
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK
(
size_k
%
gptq_marlin
::
tile_k_size
==
0
,
"size_k = "
,
size_k
,
" is not divisible by tile_k_size = "
,
gptq_marlin
::
tile_k_size
);
TORCH_CHECK
(
size_n
%
gptq_marlin
::
tile_n_size
==
0
,
"size_n = "
,
size_n
,
" is not divisible by tile_n_size = "
,
gptq_marlin
::
tile_n_size
);
TORCH_CHECK
(
num_bits
==
4
||
num_bits
==
8
,
"num_bits must be 4 or 8. Got = "
,
num_bits
);
int
const
pack_factor
=
32
/
num_bits
;
// Verify B
TORCH_CHECK
((
size_k
/
gptq_marlin
::
pack_factor
_4bit
)
==
b_q_weight
.
size
(
0
),
TORCH_CHECK
((
size_k
/
pack_factor
)
==
b_q_weight
.
size
(
0
),
"Shape mismatch: b_q_weight.size(0) = "
,
b_q_weight
.
size
(
0
),
", size_k = "
,
size_k
,
", pack_factor_4bit = "
,
gptq_marlin
::
pack_factor_4bit
);
", size_k = "
,
size_k
,
", pack_factor = "
,
pack_factor
);
TORCH_CHECK
(
b_q_weight
.
size
(
1
)
==
size_n
,
"b_q_weight.size(1) = "
,
b_q_weight
.
size
(
1
),
" is not size_n = "
,
size_n
);
...
...
@@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
b_q_weight
.
dtype
())
.
device
(
b_q_weight
.
device
());
torch
::
Tensor
out
=
torch
::
empty
(
{
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
gptq_marlin
::
pack_factor
_4bit
},
options
);
torch
::
Tensor
out
=
torch
::
empty
(
{
size_k
/
gptq_marlin
::
tile_size
,
size_n
*
gptq_marlin
::
tile_size
/
pack_factor
},
options
);
// Detect if there is act_order
bool
has_perm
=
perm
.
size
(
0
)
!=
0
;
...
...
@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
dev
);
TORCH_CHECK
(
max_shared_mem
>
0
);
if
(
has_perm
)
{
cudaFuncSetAttribute
(
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
true
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
max_shared_mem
);
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
true
>
<<<
blocks
,
gptq_marlin
::
repack_threads
,
max_shared_mem
,
stream
>>>
(
b_q_weight_ptr
,
perm_ptr
,
out_ptr
,
size_k
,
size_n
);
}
else
{
cudaFuncSetAttribute
(
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
false
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
max_shared_mem
);
gptq_marlin
::
marlin_repack_kernel
<
gptq_marlin
::
repack_threads
,
false
>
<<<
blocks
,
gptq_marlin
::
repack_threads
,
max_shared_mem
,
stream
>>>
(
b_q_weight_ptr
,
perm_ptr
,
out_ptr
,
size_k
,
size_n
);
if
(
false
)
{
}
CALL_IF
(
4
,
false
)
CALL_IF
(
4
,
true
)
CALL_IF
(
8
,
false
)
CALL_IF
(
8
,
true
)
else
{
TORCH_CHECK
(
false
,
"Unsupported repack config: num_bits = "
,
num_bits
,
", has_perm = "
,
has_perm
);
}
return
out
;
...
...
tests/models/test_gptq_marlin.py
View file @
7038e8b8
...
...
@@ -39,6 +39,13 @@ MODELS = [
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-64g-actorder_True"
),
# act_order==True, group_size=32
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-4bit-32g-actorder_True"
),
# 8-bit, act_order==True, group_size=channelwise
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit--1g-actorder_True"
),
# 8-bit, act_order==True, group_size=128
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit-128g-actorder_True"
),
# 8-bit, act_order==True, group_size=32
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"
,
"gptq-8bit-32g-actorder_True"
),
]
...
...
@@ -65,8 +72,7 @@ def test_models(
dtype
=
dtype
,
quantization
=
"marlin"
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
,
disable_custom_all_reduce
=
True
)
tensor_parallel_size
=
1
)
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
@@ -78,8 +84,7 @@ def test_models(
dtype
=
dtype
,
quantization
=
"gptq"
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
,
disable_custom_all_reduce
=
True
)
tensor_parallel_size
=
1
)
gptq_outputs
=
gptq_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
vllm/_custom_ops.py
View file @
7038e8b8
...
...
@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin
def
gptq_marlin_repack
(
b_q_weight
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
)
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_repack
(
b_q_weight
,
perm
,
size_k
,
size_n
,
num_bits
)
def
gptq_marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
)
->
torch
.
Tensor
:
return
vllm_ops
.
gptq_marlin_gemm
(
a
,
b_q_weight
,
b_scales
,
g_idx
,
perm
,
workspace
,
size_m
,
size_n
,
size_k
,
is_k_full
)
workspace
,
num_bits
,
size_m
,
size_n
,
size_k
,
is_k_full
)
# fp8
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
7038e8b8
...
...
@@ -2,7 +2,6 @@ import enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
numpy
import
torch
from
torch.nn.parameter
import
Parameter
...
...
@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
]
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def
_get_perms
():
perm
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm
)
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
perm
=
perm
.
reshape
((
-
1
,
8
))[:,
interleave
].
ravel
()
# type: ignore
perm
=
torch
.
from_numpy
(
perm
)
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
):
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
...
...
@@ -59,23 +30,21 @@ def _get_perms():
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
_perm
,
_scale_perm
,
_scale_perm_single
=
_get_perms
()
return
scale_perm
,
scale_perm_single
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
(
f
"Unsupported num_bits =
{
num_bits
}
"
)
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
):
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
_
scale_perm
)))[:,
_
scale_perm
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
_
scale_perm_single
)))[:,
_
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
...
...
@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
qweight
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
},
)
# Activation order
g_idx
=
Parameter
(
...
...
@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
})
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
...
...
@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
scales
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
})
},
)
# Quantized zero-points
qzeros
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
device
=
"meta"
),
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
device
=
"meta"
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
qzeros
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
},
)
# Allocate marlin workspace
max_workspace_size
=
(
...
...
@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else
:
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
# Repack weights
marlin_qweight
=
ops
.
gptq_marlin_repack
(
...
...
@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
...
...
@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
)
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment