Unverified Commit 7038e8b8 authored by alexm-nm's avatar alexm-nm Committed by GitHub
Browse files

[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)

parent 2a85f930
......@@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
......@@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n);
int64_t size_n,
int64_t num_bits);
#endif
void squeezellm_gemm(
......
......@@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
template <typename T, int n>
struct Vec {
T elems[n];
......@@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
......
......@@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
......@@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
......@@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else
template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
......@@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads =
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
......@@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit;
int src_k_packed = src_k / pack_factor;
cp_async4_stream(
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
......@@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit;
int first_k_packed = first_k / pack_factor;
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
......@@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[pack_factor_4bit];
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit;
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
......@@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} else {
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < 2; i++) {
int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 2; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8;
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res;
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
......@@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k,
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
......@@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
options);
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
......@@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (has_perm) {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return out;
......
......@@ -39,6 +39,13 @@ MODELS = [
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
# 8-bit, act_order==True, group_size=channelwise
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
# 8-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
# 8-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
]
......@@ -65,8 +72,7 @@ def test_models(
dtype=dtype,
quantization="marlin",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
tensor_parallel_size=1)
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
......@@ -78,8 +84,7 @@ def test_models(
dtype=dtype,
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
tensor_parallel_size=1)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
......
......@@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n)
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, size_m, size_n, size_k,
is_k_full)
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8
......
......@@ -2,7 +2,6 @@ import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import numpy
import torch
from torch.nn.parameter import Parameter
......@@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm)
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
perm = torch.from_numpy(perm)
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
......@@ -59,23 +30,21 @@ def _get_perms():
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
f"Unsupported num_bits = {num_bits}")
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size):
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
......@@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
set_weight_attrs(
qweight, {
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
},
)
# Activation order
g_idx = Parameter(
......@@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
})
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter(
torch.empty(
......@@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
set_weight_attrs(
scales, {
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
})
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(scales_and_zp_size,
output_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32,
device="meta"),
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
},
)
# Allocate marlin workspace
max_workspace_size = (
......@@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0,
dtype=torch.int,
device=cur_device),
requires_grad=False)
layer.g_idx_sort_indices = Parameter(torch.empty(
0, dtype=torch.int, device=cur_device),
requires_grad=False)
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
......@@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
......@@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
scales_size_n,
self.quant_config.group_size)
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
layer.g_idx, layer.g_idx_sort_indices,
layer.workspace, size_m, part_size_n,
part_size_k, layer.is_k_full)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment