Unverified Commit 5aa1ebd2 authored by Peng Zhang's avatar Peng Zhang Committed by GitHub
Browse files

[2/n]decouple quantization implementation from vLLM dependency (#8112)


Co-authored-by: default avatarwalker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: default avatarleoneo <1320612015@qq.com>
parent 4dbf4360
...@@ -321,6 +321,30 @@ def pack_cols( ...@@ -321,6 +321,30 @@ def pack_cols(
return q_res return q_res
def pack_rows(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_k % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[i::pack_factor, :] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
return q_res
def unpack_cols( def unpack_cols(
packed_q_w: torch.Tensor, packed_q_w: torch.Tensor,
num_bits: int, num_bits: int,
......
...@@ -254,13 +254,15 @@ set(SOURCES ...@@ -254,13 +254,15 @@ set(SOURCES
"csrc/gemm/per_token_quant_fp8.cu" "csrc/gemm/per_token_quant_fp8.cu"
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu" "csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
"csrc/gemm/qserve_w4a8_per_group_gemm.cu" "csrc/gemm/qserve_w4a8_per_group_gemm.cu"
"csrc/gemm/marlin/gptq_marlin.cu"
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu" "csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu" "csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
......
...@@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -161,6 +161,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm); m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
// GPTQ related method
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
"Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
"Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
m.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
"use_shuffle, int bit) -> Tensor");
m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/* /*
* From csrc/moe * From csrc/moe
*/ */
...@@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -207,12 +229,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()"); m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
/* /*
* From csrc/speculative * From csrc/speculative
*/ */
......
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace sglang {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val) {
unsigned int* address_as_ui = (unsigned int*)((char*)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do {
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
} while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val);
}
#endif
#endif
#endif
} // namespace gptq
} // namespace sglang
#endif
This diff is collapsed.
/*
Adapted from https://github.com/turboderp/exllamav2 and
https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
class MatrixView_half {
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half2 item_half2half2(int row, int column) const {
return __half2half2(data[row * width + column]);
}
__device__ __forceinline__ const half* item_ptr(int row, int column) const {
return &data[row * width + column];
}
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const {
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw {
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ half item(int row, int column) const {
return data[row * width + column];
}
__device__ __forceinline__ half2 item_half2(int row, int column) const {
return ((half2*)data)[(row * width + column) / 2];
}
__device__ __forceinline__ half* item_ptr(int row, int column) {
return &data[row * width + column];
}
__device__ __forceinline__ void set(int row, int column, half value) {
data[row * width + column] = value;
}
__device__ __forceinline__ void set_half2(int row, int column, half2 value) {
((half2*)data)[(row * width + column) / 2] = value;
}
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) {
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*)item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
class MatrixView_q4_column {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) {
return data[row / 8 * width + column];
}
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) {
return &data[row / 8 * width + column];
}
};
class MatrixView_q2_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) |
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row {
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width) {}
__device__ __forceinline__ int item(int row, int column) const {
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const {
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const {
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq
} // namespace sglang
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) {
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace sglang
#endif
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) {
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) {
uint32_t t0 = qa & 0x07;
uint32_t t1 = (qa & 0x38) >> 3;
qa >>= 6;
za |= (t0 << (i * 3));
za |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qb & 0x07;
uint32_t t1 = (qb & 0x38) >> 3;
qb >>= 6;
zb |= (t0 << (i * 3));
zb |= (t1 << (i * 3 + 16));
}
for (int i = 0; i < 5; i++) {
uint32_t t0 = qc & 0x07;
uint32_t t1 = (qc & 0x38) >> 3;
qc >>= 6;
zc |= (t0 << (i * 3));
zc |= (t1 << (i * 3 + 16));
}
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32(
const uint32_t q_0, const uint32_t q_1, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y8, z8);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y8, z8);
dq[4] = __hfma2(q4.as_half2, y64, z64);
dq[5] = __hadd2(q5.as_half2, z1);
dq[6] = __hfma2(q6.as_half2, y8, z8);
dq[7] = __hadd2(q7.as_half2, z1);
dq[8] = __hfma2(q8.as_half2, y8, z8);
dq[9] = __hfma2(q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace sglang
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) {
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) {
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void
dequant_4bit_8_prep_zero_scale(const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) {
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void
dequant_4bit_8_gptq(const uint32_t q_0, half2 (&dq)[4], half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) {
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled) {
dq[0] = __hfma2(q0.as_half2, y1y16[0],
z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
} else {
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1],
z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1],
z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
} // namespace gptq
} // namespace sglang
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace sglang {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}
__forceinline__ __device__ void
dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) {
half dqh[8];
for (int i = 0; i < 4; i++)
dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++)
dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++)
dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace sglang
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace sglang {
namespace gptq {
union half2_uint32 {
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16 {
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) {
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
// return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) {
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) {
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace sglang
#endif
#ifndef MARLIN_NAMESPACE_NAME #include "marlin.cuh"
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "gptq_marlin/marlin.cuh"
#include "kernel.h"
namespace MARLIN_NAMESPACE_NAME {
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in awq_marlin_repack_kernel template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
return;
}
#else #else
template <int const num_threads, int const num_bits> template <int const num_threads, int const num_bits>
...@@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel( ...@@ -178,21 +175,33 @@ __global__ void awq_marlin_repack_kernel(
} }
} }
} }
#endif
#define CALL_IF(NUM_BITS) \ } // namespace marlin
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \ #define CALL_IF(NUM_BITS) \
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \ else if (num_bits == NUM_BITS) { \
cudaFuncAttributeMaxDynamicSharedMemorySize, \ cudaFuncSetAttribute( \
max_shared_mem); \ marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \ cudaFuncAttributeMaxDynamicSharedMemorySize, \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \ max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); TORCH_CHECK(
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); size_k % marlin::tile_k_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_k_size = ",
marlin::tile_k_size);
TORCH_CHECK(
size_n % marlin::tile_n_size == 0,
"size_n = ",
size_n,
" is not divisible by tile_n_size = ",
marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits; int const pack_factor = 32 / num_bits;
...@@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64 ...@@ -216,7 +225,7 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
// Get ptrs // Get ptrs
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr()); uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
...@@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64 ...@@ -242,14 +251,3 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64
return out; return out;
} }
torch::Tensor
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
The process of fast dequantization can be summarized as a combination
of bitwise operations and floating-point computations:
weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight
Since the dequantized weights typically require subtracting the zero point and
applying a scale factor, the floating-point computation step can be fused with
the zero-point subtraction and scaling operations.
The following are the parts that need to be modified for the fused operation
of zero-point subtraction and scaling.
## INT4 => FP16/BF16 or INT8 => FP16
The floating-point computation is `__hsub2`
If has zero points:
flop(bit_op(weight)) - flop(bit_op(zp))
= sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
= bit_op(weight) - bit_op(zp)
so we don't need additional modification.
If has float zero points:
flop(bit_op(weight)) - fzp
= sub(bit_op(weight), bias) - fzp
= bit_op(weight) - (fzp + bias)
where the `fzp + bias` can be computed at weight loading. But this
may have accuracy issue, so we should not use this in most cases.
If has not zero points:
scale(flop(bit_op(weight)))
= scale(sub(bit_op(weight), bias))
= scale(bit_op(weight)) - scale(bias)
= fma(bit_op(weight), scale_factor, scale(bias))
where the `scale(bias)` can be cached. But this may have accuracy issue,
so we should not use this in most cases.
## INT8 => BF16
INT8 => BF16 is a special case, it use byte_perm instead of flop.
We cannot fused byte_perm with scaling.
## FP4/FP8 => FP16/BF16
scale(flop(bit_op(weight)))
= scale(mul(bit_op(weight), multiplier))
= mul(bit_op(weight), scale_factor * multiplier)
where `scale_factor * multiplier` can be computed at weight loading.
*/
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, sglang::ScalarTypeId w_type_id, bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), true>(int q, half2* frag_b) {
const int MASK = 0x000f000f;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, sglang::kU4B8.id(), false>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<half2, sglang::kU4.id(), true>(int q, half2* frag_b) {
dequant<half2, sglang::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, sglang::kU4.id(), false>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), true>(int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4B8.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43084308;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), true>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4B8.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU4.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kU4.id(), true>(q, frag_b);
static constexpr uint32_t SUB = 0x43004300;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), true>(int q, half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
frag_b[0] = *reinterpret_cast<half2*>(&lo);
frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
template <>
__device__ inline void dequant<half2, sglang::kU8B128.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, sglang::kU8.id(), true>(int q, half2* frag_b) {
dequant<half2, sglang::kU8B128.id(), true>(q, frag_b);
}
template <>
__device__ inline void dequant<half2, sglang::kU8.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kU8.id(), true>(q, frag_b);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8B128.id(), false>(int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kU8.id(), false>(int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), true>(int q, half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, sglang::kFE4M3fn.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE4M3fn.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kFE4M3fn.id(), true>(q, frag_b);
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), true>(int q, half2* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}
template <>
__device__ inline void dequant<half2, sglang::kFE2M1f.id(), false>(int q, half2* frag_b) {
dequant<half2, sglang::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(int q, nv_bfloat162* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70007000;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
template <>
__device__ inline void dequant<nv_bfloat162, sglang::kFE2M1f.id(), false>(int q, nv_bfloat162* frag_b) {
dequant<nv_bfloat162, sglang::kFE2M1f.id(), true>(q, frag_b);
// Constants for FP4 (E2M1) and BF16 formats
constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
// Construct and apply exponent bias
constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to half2 and apply bias
frag_b[1] = __hmul2(frag_b[1], bias_reg);
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
int Out1 = (q & 0xFF00FF00) >> 1;
;
q <<= 8;
int Out2 = (q & 0xFF00FF00) >> 1;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
};
template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q, nv_bfloat162* frag_b) {
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
constexpr int MASK = 0x7F007F00;
// Extract and shift FP8 values to BF16 format
int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 8;
int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
This diff is collapsed.
#ifndef MARLIN_NAMESPACE_NAME #include "marlin.cuh"
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
namespace MARLIN_NAMESPACE_NAME {
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in gptq_marlin_repack_kernel template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr,
uint32_t* __restrict__ out_ptr,
int size_k,
int size_n) {
return;
}
#else #else
template <int const num_threads, int const num_bits, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel( __global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ b_q_weight_ptr,
...@@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -23,7 +25,7 @@ __global__ void gptq_marlin_repack_kernel(
int n_tiles = size_n / tile_n_size; int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles; auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) { if (start_k_tile >= k_tiles) {
return; return;
} }
...@@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -79,8 +81,8 @@ __global__ void gptq_marlin_repack_kernel(
if constexpr (has_perm) { if constexpr (has_perm) {
if (threadIdx.x < stage_size) { if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr); uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
...@@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -94,8 +96,8 @@ __global__ void gptq_marlin_repack_kernel(
} else { } else {
if (threadIdx.x < stage_size) { if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads; auto k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads; auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor; int first_k_packed = first_k / pack_factor;
...@@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -114,8 +116,8 @@ __global__ void gptq_marlin_repack_kernel(
return; return;
} }
int warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32; auto th_id = threadIdx.x % 32;
if (warp_id >= 4) { if (warp_id >= 4) {
return; return;
...@@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel( ...@@ -237,22 +239,35 @@ __global__ void gptq_marlin_repack_kernel(
} }
} }
} }
#endif
#define CALL_IF(NUM_BITS, HAS_PERM) \ } // namespace marlin
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \ #define CALL_IF(NUM_BITS, HAS_PERM) \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>, \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncAttributeMaxDynamicSharedMemorySize, \ cudaFuncSetAttribute( \
max_shared_mem); \ marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM>, \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM> \ cudaFuncAttributeMaxDynamicSharedMemorySize, \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} }
torch::Tensor torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); TORCH_CHECK(
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); size_k % marlin::tile_k_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_k_size = ",
marlin::tile_k_size);
TORCH_CHECK(
size_n % marlin::tile_n_size == 0,
"size_n = ",
size_n,
" is not divisible by tile_n_size = ",
marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits; int const pack_factor = 32 / num_bits;
...@@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_ ...@@ -280,7 +295,7 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
// Alloc buffers // Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device()); auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options); torch::Tensor out = torch::empty({size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options);
// Detect if there is act_order // Detect if there is act_order
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
...@@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_ ...@@ -312,22 +327,3 @@ gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_
return out; return out;
} }
torch::Tensor gptq_marlin_repack_meta(
torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
// }
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
// }
} // namespace MARLIN_NAMESPACE_NAME
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <
typename scalar_t, // compute dtype, half or nv_float16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}
...@@ -10,11 +10,10 @@ ...@@ -10,11 +10,10 @@
#include <iostream> #include <iostream>
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin
#endif #endif
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
// Marlin params // Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more // 8 warps are a good choice since every SM has 4 schedulers and having more
...@@ -91,6 +90,7 @@ template <int n> ...@@ -91,6 +90,7 @@ template <int n>
__device__ inline void cp_async_wait() { __device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
} }
#endif #endif
} // namespace MARLIN_NAMESPACE_NAME } // namespace MARLIN_NAMESPACE_NAME
#ifndef _data_types_cuh #ifndef _data_types_cuh
#define _data_types_cuh #define _data_types_cuh
#include <cuda_bf16.h> #include <cuda_bf16.h>
...@@ -7,7 +6,7 @@ ...@@ -7,7 +6,7 @@
#include "marlin.cuh" #include "marlin.cuh"
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin
#endif #endif
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
......
This diff is collapsed.
#pragma once
#include <Python.h>
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment