Unverified Commit 7038e8b8 authored by alexm-nm's avatar alexm-nm Committed by GitHub
Browse files

[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)

parent 2a85f930
...@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm( ...@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor &g_idx, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &perm,
torch::Tensor &workspace, torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m, int64_t size_m,
int64_t size_n, int64_t size_n,
int64_t size_k, int64_t size_k,
...@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack( ...@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight, torch::Tensor &b_q_weight,
torch::Tensor &perm, torch::Tensor &perm,
int64_t size_k, int64_t size_k,
int64_t size_n); int64_t size_n,
int64_t num_bits);
#endif #endif
void squeezellm_gemm( void squeezellm_gemm(
......
...@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64; ...@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16; static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
template <typename T, int n> template <typename T, int n>
struct Vec { struct Vec {
T elems[n]; T elems[n];
...@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool ...@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"(smem), "l"(glob_ptr), "n"(BYTES)); "r"(smem), "l"(glob_ptr), "n"(BYTES));
} }
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile("{\n"
" .reg .b64 p;\n" " cp.async.cg.shared.global [%0], [%1], %2;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" ::"r"(smem), "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES)); "l"(glob_ptr), "n"(BYTES));
} }
......
...@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4; ...@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const *__restrict__ perm_ptr,
...@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin } // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) { int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1}); return torch::empty({1, 1});
...@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, ...@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else #else
template <int const num_threads, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) { uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size; int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
...@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
...@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast<uint32_t const *>(sh_perm_ptr); reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id]; int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit; int src_k_packed = src_k / pack_factor;
cp_async4_stream( cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id], &sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&( reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
...@@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int n_id = threadIdx.x % stage_n_threads; int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit; int first_k_packed = first_k / pack_factor;
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>( reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)]))); first_n + (n_id * 4)])));
} }
} }
...@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int cur_n = warp_id * 16 + tc_col; int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr); uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr); uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[pack_factor_4bit]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx]; uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit; uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val; vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val; vals[4 + i] = b2_cur_val;
...@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} else { } else {
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n]; uint32_t b1_vals[tile_ints];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; uint32_t b2_vals[tile_ints];
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
#pragma unroll #pragma unroll
for (int i = 0; i < 2; i++) { for (int i = 0; i < tile_ints; i++) {
int cur_elem = tc_row + tc_offsets[i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
} }
#pragma unroll #pragma unroll
for (int i = 2; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8; int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; int cur_int = cur_elem / pack_factor;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) { for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4); res |= vals[pack_idx[i]] << (i * 4);
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; out_ptr[out_offset + th_id * 4 + warp_id] = res;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res; } else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
}; };
auto start_pipes = [&](int k_tile_id, int n_tile_id) { auto start_pipes = [&](int k_tile_id, int n_tile_id) {
...@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, ...@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin } // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) { int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B // Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", size_k = ", size_k, ", pack_factor = ", pack_factor);
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
TORCH_CHECK(b_q_weight.size(1) == size_n, TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1), "b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n); " is not size_n = ", size_n);
...@@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, ...@@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype()) .dtype(b_q_weight.dtype())
.device(b_q_weight.device()); .device(b_q_weight.device());
torch::Tensor out = torch::empty( torch::Tensor out =
{size_k / gptq_marlin::tile_size, torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, size_n * gptq_marlin::tile_size / pack_factor},
options); options);
// Detect if there is act_order // Detect if there is act_order
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
...@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, ...@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
if (has_perm) { if (false) {
cudaFuncSetAttribute( }
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>, CALL_IF(4, false)
cudaFuncAttributeMaxDynamicSharedMemorySize, CALL_IF(4, true)
max_shared_mem); CALL_IF(8, false)
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true> CALL_IF(8, true)
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, else {
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
} }
return out; return out;
......
...@@ -39,6 +39,13 @@ MODELS = [ ...@@ -39,6 +39,13 @@ MODELS = [
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32 # act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
# 8-bit, act_order==True, group_size=channelwise
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
# 8-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
# 8-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
] ]
...@@ -65,8 +72,7 @@ def test_models( ...@@ -65,8 +72,7 @@ def test_models(
dtype=dtype, dtype=dtype,
quantization="marlin", quantization="marlin",
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1, tensor_parallel_size=1)
disable_custom_all_reduce=True)
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
...@@ -78,8 +84,7 @@ def test_models( ...@@ -78,8 +84,7 @@ def test_models(
dtype=dtype, dtype=dtype,
quantization="gptq", quantization="gptq",
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1, tensor_parallel_size=1)
disable_custom_all_reduce=True)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens, max_tokens,
num_logprobs) num_logprobs)
......
...@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, ...@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin # gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int) -> torch.Tensor: size_k: int, size_n: int,
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor, b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor, size_m: int, perm: torch.Tensor, workspace: torch.Tensor,
size_n: int, size_k: int, num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor: is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, size_m, size_n, size_k, workspace, num_bits, size_m, size_n,
is_k_full) size_k, is_k_full)
# fp8 # fp8
......
...@@ -2,7 +2,6 @@ import enum ...@@ -2,7 +2,6 @@ import enum
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import numpy
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64 ...@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True] GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Precompute permutations for Marlin weight and scale shuffling # Permutations for Marlin scale shuffling
# def get_scale_perms(num_bits):
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm)
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
perm = torch.from_numpy(perm)
scale_perm = [] scale_perm = []
for i in range(8): for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm.extend([i + 8 * j for j in range(8)])
...@@ -59,23 +30,21 @@ def _get_perms(): ...@@ -59,23 +30,21 @@ def _get_perms():
for i in range(4): for i in range(4):
scale_perm_single.extend( scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single return scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
def get_pack_factor(num_bits): def get_pack_factor(num_bits):
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
f"Unsupported num_bits = {num_bits}") ), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size): def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else: else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous() s = s.reshape((-1, size_n)).contiguous()
return s return s
...@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
qweight, { qweight,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": 0, "input_dim": 0,
"output_dim": 1, "output_dim": 1,
"packed_dim": 0, "packed_dim": 0,
"pack_factor": self.quant_config.pack_factor, "pack_factor": self.quant_config.pack_factor,
}) },
)
# Activation order # Activation order
g_idx = Parameter( g_idx = Parameter(
...@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
# Ignore warning from fused linear layers such as QKVParallelLinear. # Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, { set_weight_attrs(
**extra_weight_attrs, "input_dim": 0, g_idx,
"ignore_warning": True {
}) **extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter( g_idx_sort_indices = Parameter(
torch.empty( torch.empty(
...@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
scales, { scales,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": scales_and_zp_input_dim, "input_dim": scales_and_zp_input_dim,
"output_dim": 1, "output_dim": 1,
}) },
)
# Quantized zero-points # Quantized zero-points
qzeros = Parameter( qzeros = Parameter(
torch.empty(scales_and_zp_size, torch.empty(
output_size_per_partition // scales_and_zp_size,
self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
device="meta"), device="meta",
),
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
qzeros, { qzeros,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": scales_and_zp_input_dim, "input_dim": scales_and_zp_input_dim,
"output_dim": 1, "output_dim": 1,
"packed_dim": 1, "packed_dim": 1,
"pack_factor": self.quant_config.pack_factor, "pack_factor": self.quant_config.pack_factor,
}) },
)
# Allocate marlin workspace # Allocate marlin workspace
max_workspace_size = ( max_workspace_size = (
...@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else: else:
# Reset g_idx related tensors # Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0, layer.g_idx = Parameter(
dtype=torch.int, torch.empty(0, dtype=torch.int, device=cur_device),
device=cur_device), requires_grad=False,
requires_grad=False) )
layer.g_idx_sort_indices = Parameter(torch.empty( layer.g_idx_sort_indices = Parameter(
0, dtype=torch.int, device=cur_device), torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False) requires_grad=False,
)
# Repack weights # Repack weights
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
...@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx_sort_indices, layer.g_idx_sort_indices,
part_size_k, part_size_k,
part_size_n, part_size_n,
self.quant_config.weight_bits,
) )
replace_tensor("qweight", marlin_qweight) replace_tensor("qweight", marlin_qweight)
...@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if self.quant_config.desc_act: if self.quant_config.desc_act:
scales_size_k = full_size_k scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, marlin_scales = marlin_permute_scales(
scales_size_n, layer.scales,
self.quant_config.group_size) scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales) replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, output = ops.gptq_marlin_gemm(
layer.g_idx, layer.g_idx_sort_indices, reshaped_x,
layer.workspace, size_m, part_size_n, layer.qweight,
part_size_k, layer.is_k_full) layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None: if bias is not None:
output.add_(bias) # In-place add output.add_(bias) # In-place add
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment