Unverified Commit 9935720c authored by Daniël de Kok, committed by GitHub

Add support for repacking AWQ weights for GPTQ-Marlin (#2278)

* Add support for repacking AWQ weights for GPTQ-Marlin

So far we couldn't support AWQ because virtually all AWQ models use
asymmetric quantization, which GPTQ-Marlin did not support. GPTQ-Marlin
has recently added support for AWQ repacking and AWQ asymmetric
quantization (zero_point=True).

This change updates all GPTQ-Marlin kernels from upstream and wires up
AWQ support. For now, enabling AWQ with Marlin requires running TGI with
`--quantize gptq`.

* Enable Marlin for supported AWQ configurations by default

This makes the AWQ -> GPTQ repack test redundant, since we are now
testing this with the regular AWQ test.
parent 5fca30ee
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else

namespace marlin {

template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace marlin

#define CALL_IF(NUM_BITS)                                                    \
  else if (num_bits == NUM_BITS) {                                           \
    cudaFuncSetAttribute(                                                    \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, out_ptr, size_k, size_n);                        \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
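The `undo_pack` tables in `repack_tile` above are just the inverse of the FasterTransformer interleaving order that AWQ uses when packing 32 / num_bits values into each uint32 column (the same order the kernel re-packs its output with as `pack_idx`). A quick sanity-check sketch of that relationship, not part of the commit, assuming only numpy:

import numpy as np

# Packed slot i holds the logical value pack_idx[i]; the kernel's undo_pack
# tables are simply the inverse permutations (argsort) of these orders.
pack_idx_4bit = np.array([0, 2, 4, 6, 1, 3, 5, 7])
pack_idx_8bit = np.array([0, 2, 1, 3])

print(np.argsort(pack_idx_4bit))  # [0 4 1 5 2 6 3 7] == undo_pack for num_bits == 4
print(np.argsort(pack_idx_8bit))  # [0 2 1 3]          == undo_pack for num_bits == 8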
@@ -3,6 +3,8 @@
 #include "ext.hh"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("awq_marlin_repack", &awq_marlin_repack,
+        "Repack AWQ parameters for Marlin");
   m.def("gptq_marlin_gemm", &gptq_marlin_gemm,
         "Marlin gemm with GPTQ compatibility");
   m.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin sparse 2:4 gemm");
@@ -6,11 +6,15 @@
 // No support for async
 #else

+torch::Tensor awq_marlin_repack(torch::Tensor &b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
-                               torch::Tensor &b_scales, torch::Tensor &g_idx,
-                               torch::Tensor &perm, torch::Tensor &workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor &b_scales, torch::Tensor &b_zeros,
+                               torch::Tensor &g_idx, torch::Tensor &perm,
+                               torch::Tensor &workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);

 torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                   torch::Tensor &b_meta,
@@ -27,8 +31,8 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                           torch::Tensor &b_scales, torch::Tensor &workspace,
                           int64_t size_m, int64_t size_n, int64_t size_k);

-torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                              torch::Tensor& b_scales, torch::Tensor& workspace,
+torch::Tensor fp8_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+                              torch::Tensor &b_scales, torch::Tensor &workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);
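For reference, the `awq_marlin_repack` entry point declared here is what the Python loader calls further down. A rough usage sketch, assuming the extension is built as `marlin_kernels` (see the setup.py change below), an SM >= 8.0 GPU, and shapes that satisfy the checks in awq_marlin_repack.cu:

import torch
import marlin_kernels  # the CUDAExtension built from these sources

bits = 4
pack_factor = 32 // bits
size_k, size_n = 4096, 4096  # in_features, out_features of an AWQ layer

# AWQ stores qweight as int32 with shape (size_k, size_n // pack_factor).
qweight = torch.randint(
    0, 2**31 - 1, (size_k, size_n // pack_factor), dtype=torch.int32, device="cuda"
)

repacked = marlin_kernels.awq_marlin_repack(qweight, size_k, size_n, bits)
print(repacked.shape)  # (size_k // 16, size_n * 16 // pack_factor), Marlin tile layout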
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */

-#include "./gptq_marlin.cuh"
-#include "./gptq_marlin_dtypes.cuh"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"

-using namespace gptq_marlin;
+using namespace marlin;

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \
   static_assert(std::is_same<scalar_t, half>::value ||          \
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);

   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     num_groups = b_scales.size(0);

   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
This diff is collapsed.
#include "gptq_marlin.cuh" #include "marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {} int size_k, int size_n) {}
} // namespace gptq_marlin } // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
...@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else #else
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) { int size_k, int size_n) {
...@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel( ...@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
} }
} }
} // namespace gptq_marlin } // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \ #define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
NUM_BITS, HAS_PERM>, \ HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \ HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \ <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits) { int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); " is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); " is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits); "num_bits must be 4 or 8. Got = ", num_bits);
...@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype()) .dtype(b_q_weight.dtype())
.device(b_q_weight.device()); .device(b_q_weight.device());
torch::Tensor out = torch::Tensor out = torch::empty(
torch::empty({size_k / gptq_marlin::tile_size, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
size_n * gptq_marlin::tile_size / pack_factor}, options);
options);
// Detect if there is act_order // Detect if there is act_order
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
......
@@ -9,7 +9,9 @@
 #include <cuda_runtime.h>
 #include <iostream>

-namespace gptq_marlin {
+namespace marlin {
+
+// Marlin params

 // 8 warps are a good choice since every SM has 4 schedulers and having more
 // than 1 warp per schedule allows some more latency hiding. At the same time,
@@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
 static constexpr int tile_size = 16;
 static constexpr int max_par = 16;

+// Repack params
+static constexpr int repack_stages = 8;
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+// Helpers
 template <typename T, int n>
 struct Vec {
   T elems[n];
@@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {

 #endif

-}  // namespace gptq_marlin
+}  // namespace marlin
@@ -30,7 +30,7 @@ inline std::string str(T x) {
   return std::to_string(x);
 }

-namespace marlin {
+namespace marlin_dense {

 constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
@@ -1040,7 +1040,7 @@ void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
   }
 }

-}  // namespace marlin
+}  // namespace marlin_dense

 torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                           torch::Tensor& b_scales, torch::Tensor& workspace,
@@ -1054,24 +1054,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   TORCH_CHECK(size_k == a.size(1),
               "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                   ", size_k = " + str(size_k));
-  TORCH_CHECK(size_k % marlin::tile_size == 0,
-              "size_k = " + str(size_k) +
-                  " is not divisible by tile_size = " + str(marlin::tile_size));
-  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
+              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
+                  str(marlin_dense::tile_size));
+  TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = " +
                   str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
-                  ", tile_size = " + str(marlin::tile_size));
+                  ", tile_size = " + str(marlin_dense::tile_size));

   // Verify N
   TORCH_CHECK(b_scales.size(1) == size_n,
               "b_scales.size(1) = " + str(b_scales.size(1)) +
                   ", size_n = " + str(size_n));
-  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
-              "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
-              " is not divisible by tile_size = " + str(marlin::tile_size));
+  TORCH_CHECK(
+      b_q_weight.size(1) % marlin_dense::tile_size == 0,
+      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
+          " is not divisible by tile_size = " + str(marlin_dense::tile_size));

-  int actual_size_n =
-      (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit;
+  int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
+                      marlin_dense::pack_factor_4bit;
   TORCH_CHECK(
       size_n == actual_size_n,
       "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));
@@ -1116,21 +1117,22 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               "Unexpected groupsize = " + str(groupsize));

   // Verify workspace size
-  TORCH_CHECK(
-      size_n % marlin::min_thread_n == 0,
-      "size_n = " + str(size_n) +
-          ", is not divisible by min_thread_n = " + str(marlin::min_thread_n));
-  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
+  TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
+              "size_n = " + str(size_n) +
+                  ", is not divisible by min_thread_n = " +
+                  str(marlin_dense::min_thread_n));
+  int min_workspace_size =
+      (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = " + str(workspace.numel()) +
                   " is below min_workspace_size = " + str(min_workspace_size));

   int dev = a.get_device();
-  marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
-                      b_scales.data_ptr(), size_m, size_n, size_k,
-                      workspace.data_ptr(), groupsize, dev,
-                      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
-                      sms, marlin::max_par);
+  marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
+                            b_scales.data_ptr(), size_m, size_n, size_k,
+                            workspace.data_ptr(), groupsize, dev,
+                            at::cuda::getCurrentCUDAStream(dev), thread_k,
+                            thread_n, sms, marlin_dense::max_par);
   return c;
 }
 #ifndef _data_types_cuh
 #define _data_types_cuh
-#include "gptq_marlin.cuh"
+#include "marlin.cuh"
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>

-namespace gptq_marlin {
+namespace marlin {

 template <typename scalar_t>
 class ScalarType {};
@@ -23,6 +23,7 @@ class ScalarType<half> {
   using FragB = Vec<half2, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<half2, 1>;
+  using FragZP = Vec<half2, 4>;

   static __device__ float inline num2float(const half x) {
     return __half2float(x);
@@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
   using FragB = Vec<nv_bfloat162, 2>;
   using FragC = Vec<float, 4>;
   using FragS = Vec<nv_bfloat162, 1>;
+  using FragZP = Vec<nv_bfloat162, 4>;

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static __device__ float inline num2float(const nv_bfloat16 x) {
@@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
 #endif
 };

-}  // namespace gptq_marlin
+}  // namespace marlin

 #endif
@@ -9,6 +9,7 @@ setup(
     CUDAExtension(
         name="marlin_kernels",
         sources=[
+            "marlin_kernels/awq_marlin_repack.cu",
             "marlin_kernels/fp8_marlin.cu",
             "marlin_kernels/gptq_marlin.cu",
             "marlin_kernels/gptq_marlin_repack.cu",
@@ -156,16 +156,26 @@ class GPTQWeightsLoader(WeightsLoader):
                     f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                 )

-            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            if not self.sym:
+                qzeros = weights.get_tensor(f"{prefix}.qzeros")
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_tensor(f"{prefix}.g_idx")
+
             scales = weights.get_tensor(f"{prefix}.scales")

             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -275,14 +285,26 @@ class GPTQWeightsLoader(WeightsLoader):
             quantize=self.quantize,
             sym=self.sym,
         ):
-            g_idx = weights.get_tensor(f"{prefix}.g_idx")
+            if not self.sym:
+                qzeros = weights.get_packed_sharded(
+                    f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+                )
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_tensor(f"{prefix}.g_idx")
+
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -349,18 +371,31 @@ class GPTQWeightsLoader(WeightsLoader):
             quantize=self.quantize,
             sym=self.sym,
         ):
-            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
-            for w2 in w[1:]:
-                torch.testing.assert_close(w2, w[0])
-            g_idx = w[0]
+            if not self.sym:
+                qzeros = torch.cat(
+                    [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+                )
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]

             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=False,
             )
@@ -438,7 +473,19 @@ class GPTQWeightsLoader(WeightsLoader):
                     f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
                 )

-            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+            if not self.sym:
+                if self.desc_act or self.groupsize == -1:
+                    qzeros = weights.get_tensor(f"{prefix}.qzeros")
+                else:
+                    qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
+            else:
+                qzeros = None
+
+            if self.quant_method == "awq":
+                g_idx = None
+            else:
+                g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
+
             if self.desc_act or self.groupsize == -1:
                 scales = weights.get_tensor(f"{prefix}.scales")
             else:
@@ -449,10 +496,12 @@ class GPTQWeightsLoader(WeightsLoader):
             return repack_gptq_for_marlin(
                 qweight=qweight,
                 scales=scales,
+                qzeros=qzeros,
                 g_idx=g_idx,
                 bits=self.bits,
                 desc_act=self.desc_act,
                 groupsize=self.groupsize,
+                quant_method=self.quant_method,
                 sym=self.sym,
                 sharded_infeatures=sharded_in_features,
             )
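All four loading paths above follow the same selection rule for the extra tensors. Roughly, as an illustrative helper that is not in the codebase (`weights` and `prefix` stand in for the real objects):

def _marlin_repack_inputs(weights, prefix, quant_method: str, sym: bool):
    # Asymmetric checkpoints (AWQ with zero_point=True, or GPTQ with sym=False)
    # carry zero points; symmetric ones do not.
    qzeros = weights.get_tensor(f"{prefix}.qzeros") if not sym else None
    # AWQ checkpoints have no activation-order index; GPTQ ones provide g_idx.
    g_idx = None if quant_method == "awq" else weights.get_tensor(f"{prefix}.g_idx")
    return qzeros, g_idx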
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

+import numpy
 import torch
 import torch.nn as nn
 from loguru import logger
@@ -174,11 +175,12 @@ def can_use_gptq_marlin(
         SYSTEM == "cuda"
         and marlin_kernels is not None
         and has_sm_8_0
-        and quantize == "gptq"
-        and quant_method == "gptq"
+        and quantize in {"awq", "gptq"}
+        and quant_method in {"awq", "gptq"}
         and bits in GPTQ_MARLIN_BITS
         and groupsize in GPTQ_MARLIN_GROUP_SIZES
-        and sym
+        # We only support asymmetric quantization for AWQ.
+        and (sym or quant_method == "awq")
     )
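With the relaxed gate, a common AWQ configuration now qualifies for the GPTQ-Marlin path. A hedged example call (the exact keyword signature of `can_use_gptq_marlin` is inferred from the parameters visible in this diff):

# Typical 4-bit AWQ checkpoint: group_size=128, zero_point=True -> sym=False.
use_marlin = can_use_gptq_marlin(
    bits=4,
    groupsize=128,
    quantize="awq",
    quant_method="awq",
    sym=False,  # previously `and sym` alone excluded AWQ models
)
# On a CUDA GPU with SM >= 8.0 and marlin_kernels installed this is now True.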
@@ -234,6 +236,7 @@ class GPTQMarlinWeight(Weight):
     """

     qweight: torch.Tensor
+    qzeros: torch.Tensor
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -256,11 +259,13 @@ class GPTQMarlinWeight(Weight):
 def repack_gptq_for_marlin(
     *,
     qweight: torch.Tensor,
+    qzeros: Optional[torch.Tensor],
     scales: torch.Tensor,
-    g_idx: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     bits: int,
     desc_act: bool,
     groupsize: int,
+    quant_method: str,
     sym: bool,
     sharded_infeatures: bool,
 ) -> GPTQMarlinWeight:
@@ -279,30 +284,54 @@ def repack_gptq_for_marlin(
         raise RuntimeError(
             f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not sym:
+    if not (sym or quant_method == "awq"):
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )

+    log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
+
     weights_per_int = 32 // bits
-    in_features = qweight.shape[0] * weights_per_int
+    in_features = qweight.shape[0]
     out_features = qweight.shape[1]
+
+    # AWQ uses column packing, GPTQ uses row packing
+    if quant_method == "awq":
+        out_features *= weights_per_int
+    else:
+        in_features *= weights_per_int
+
     if in_features % groupsize != 0:
         raise ValueError(
             f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
         )

-    if desc_act and groupsize != -1:
+    if g_idx is not None and desc_act and groupsize != -1:
         perm = torch.argsort(g_idx).to(torch.int)
         g_idx = g_idx[perm]
     else:
         perm = torch.empty(0, dtype=torch.int, device=qweight.device)
         g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)

-    repacked = marlin_kernels.gptq_marlin_repack(
-        qweight, perm, in_features, out_features, bits
-    )
+    if quant_method == "awq":
+        repacked = marlin_kernels.awq_marlin_repack(
+            qweight, in_features, out_features, bits
+        )
+        if qzeros is not None:
+            qzeros = awq_to_marlin_zero_points(
+                qzeros,
+                in_features // groupsize,
+                out_features,
+                bits,
+            )
+    else:
+        repacked = marlin_kernels.gptq_marlin_repack(
+            qweight, perm, in_features, out_features, bits
+        )
+
+    if qzeros is None:
+        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)

     scales = permute_scales(scales)
@@ -310,6 +339,7 @@ def repack_gptq_for_marlin(
     return GPTQMarlinWeight(
         qweight=repacked,
+        qzeros=qzeros,
         scales=scales,
         g_idx=g_idx,
         perm=perm,
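A worked example of the packing difference handled above, for a hypothetical 4-bit layer with 4096 input and 11008 output features (weights_per_int = 32 // 4 = 8):

def unpacked_features(qweight_shape, bits, quant_method):
    weights_per_int = 32 // bits
    in_features, out_features = qweight_shape
    if quant_method == "awq":      # AWQ packs along columns
        out_features *= weights_per_int
    else:                          # GPTQ packs along rows
        in_features *= weights_per_int
    return in_features, out_features

print(unpacked_features((4096, 1376), 4, "awq"))   # -> (4096, 11008)
print(unpacked_features((512, 11008), 4, "gptq"))  # -> (4096, 11008)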
@@ -343,6 +373,7 @@ class GPTQMarlinLinear(nn.Module):
         self.is_full_k = weight.is_full_k

         self.qweight = weight.qweight
+        self.qzeros = weight.qzeros
         self.scales = weight.scales
         self.g_idx = weight.g_idx
         self.perm = weight.perm
@@ -363,6 +394,7 @@ class GPTQMarlinLinear(nn.Module):
             A_flat,
             self.qweight,
             self.scales,
+            self.qzeros,
             self.g_idx,
             self.perm,
             self.workspace,
@@ -371,6 +403,7 @@ class GPTQMarlinLinear(nn.Module):
             self.scales.shape[1],
             A_flat.shape[1],
             self.is_full_k,
+            self.qzeros.numel() > 0,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
@@ -688,3 +721,116 @@ class MarlinLinear(nn.Module):
             C += self.bias

         return C
+
+# Functions below are from vLLM
+
+
+def get_pack_factor(bits: int) -> int:
+    if 32 % bits != 0:
+        raise ValueError(f"Cannot pack {bits} bit values into uint32")
+    return 32 // bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
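A quick round-trip check of the two helpers above — `unpack_cols` should invert `pack_cols` for values that fit in num_bits. Illustrative only, not part of the diff:

num_bits, size_k, size_n = 4, 16, 64
q = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q, num_bits, size_k, size_n)       # shape (size_k, size_n // 8)
restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(q, restored)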
+
+
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # Permute zero-points in a similar way to scales, but do not use the
+    # "single" permutation, since zero-points are applied on every MMA
+    zp = zp.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+    zp = zp.reshape((-1, size_n)).contiguous()
+    zp = pack_cols(zp, num_bits, size_k, size_n)
+
+    return zp
+
+
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # AWQ zero-points are quantized and packed on the column dim.
+    # In addition, the values are permuted based on dequantizer.
+    # Here we undo both of these, and then apply marlin permutation
+    # and pack it back.
+    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+    # Undo interleaving (use argsort(..) to get inverse perm)
+    if num_bits == 4:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+    elif num_bits == 8:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+    q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+    return marlin_zp
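Sketch of how the repacking path above feeds this conversion, for a hypothetical 4-bit AWQ layer with groupsize 128, 4096 in_features and 11008 out_features (illustrative values only):

bits, groupsize, in_features, out_features = 4, 128, 4096, 11008
num_groups = in_features // groupsize          # 32 rows of zero points, one per group
packed_cols = out_features // (32 // bits)     # 1376 packed int32 columns

awq_qzeros = torch.randint(
    0, 2**31 - 1, (num_groups, packed_cols), dtype=torch.int32
)
marlin_zp = awq_to_marlin_zero_points(awq_qzeros, num_groups, out_features, bits)
# Result stays packed: (num_groups, out_features // (32 // bits)).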
@@ -34,7 +34,7 @@ def _get_quantizer_config(model_id, revision):
     groupsize = -1
     quant_method = "gptq"
     checkpoint_format = None
-    sym = True
+    sym = False
     desc_act = False

     filename = "config.json"
@@ -52,12 +52,17 @@ def _get_quantizer_config(model_id, revision):
                 activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
             )

+        if "zero_point" in data["quantization_config"]:
+            sym = not data["quantization_config"]["zero_point"]
+            quant_method = "awq"
+        elif "sym" in data["quantization_config"]:
+            sym = data["quantization_config"]["sym"]
+
         bits = data["quantization_config"]["bits"]
         groupsize = data["quantization_config"]["group_size"]
         # Order is important here, desc_act is missing on some real models
         quant_method = data["quantization_config"]["quant_method"]
         checkpoint_format = data["quantization_config"].get("checkpoint_format")
-        sym = data["quantization_config"]["sym"]
         desc_act = data["quantization_config"]["desc_act"]
     except Exception:
         filename = "quantize_config.json"
@@ -72,7 +77,13 @@ def _get_quantizer_config(model_id, revision):
             data = json.load(f)
         bits = data["bits"]
         groupsize = data["group_size"]
-        sym = data["sym"]
+
+        if "zero_point" in data:
+            sym = not data["zero_point"]
+            quant_method = "awq"
+        elif "sym" in data:
+            sym = data["sym"]
+
         desc_act = data["desc_act"]
         if "version" in data and data["version"] == "GEMM":
             quant_method = "awq"
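Example of what the new detection derives from a typical AWQ quantization_config (illustrative values):

quantization_config = {
    "quant_method": "awq",
    "zero_point": True,   # asymmetric quantization
    "bits": 4,
    "group_size": 128,
    "version": "GEMM",
}

# With the change above, the presence of "zero_point" yields:
#   quant_method = "awq"
#   sym          = not quantization_config["zero_point"]  # -> False
# which, combined with the relaxed can_use_gptq_marlin() check, lets supported
# AWQ checkpoints default to the GPTQ-Marlin kernels.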