Unverified Commit 1656ad37 authored by Jinzhen Lin, committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
parent fa59fe41
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#include "marlin.cuh"
#include "core/registration.h"
// Preprocess kernel for weight formats without zero points (like gptq).
//
// Each thread rewrites exactly one packed int32 word (eight 4-bit values):
// thread `threadIdx.x` of block `blockIdx.x` handles word
// `blockIdx.x * 32 + threadIdx.x`, so the kernel expects 32-thread blocks and
// a 1-D grid that covers the buffer exactly. No bounds check is performed.
//
// Every nibble v is replaced, at the same nibble position, by
//   v >= 8 ? v - 8 : 15 - v
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
    // qweight: (size_k * size_n // 8,)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output) {
  const auto word_idx = blockIdx.x * 32 + threadIdx.x;
  const uint32_t packed = static_cast<uint32_t>(qweight[word_idx]);

  uint32_t remapped = 0;
#pragma unroll
  for (int shift = 0; shift < 32; shift += 4) {
    // Isolate the nibble that lives at bit offset `shift`.
    const uint32_t nibble = (packed >> shift) & 0xFu;
    const uint32_t mapped = nibble >= 8u ? nibble - 8u : 15u - nibble;
    remapped |= mapped << shift;
  }

  output[word_idx] = static_cast<int32_t>(remapped);
}
// Preprocess kernel for the awq format only (with zero points and with the awq
// weight layout).
//
// Launch layout: 32 threads per block, grid = (size_k / 32, size_n / 8).
// Thread `threadIdx.x` of block (blockIdx.x, blockIdx.y) handles row
// `blockIdx.x * 32 + threadIdx.x` and packed column `blockIdx.y`; it reads one
// packed weight word plus the zero-point word of the row's group and writes
// one output word. No bounds check is performed, so the grid must match the
// tensor shape exactly.
//
// Every weight nibble v with zero-point nibble z is replaced, at the same
// nibble position, by
//   v >= z ? v - z : 15 - v
//
// Element offsets are computed in 64-bit arithmetic: the original 32-bit
// expression `row * size_n / 8` wrapped silently once row * size_n reached
// 2^32. For inputs that did not wrap, the computed offsets are unchanged.
//
// `size_k` is not read by the kernel; it is kept so the signature stays
// compatible with existing launches.
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
    // AWQ qweight: (size_k, size_n // 8)
    const int32_t* __restrict__ qweight,
    // output: same shape with qweight
    int32_t* __restrict__ output,
    // AWQ zeros: (size_k // group_size, size_n // 8)
    const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
    int32_t group_size) {
  const int64_t row = static_cast<int64_t>(blockIdx.x) * 32 + threadIdx.x;
  const int64_t col = blockIdx.y;

  // Same formulas as before (multiply first, then divide by 8), but evaluated
  // in int64_t so large matrices cannot wrap around.
  const int64_t weight_idx = row * size_n / 8 + col;
  const int64_t zero_idx = row / group_size * size_n / 8 + col;

  int32_t val = qweight[weight_idx];
  int32_t zero = qzeros[zero_idx];
  int32_t new_val = 0;

#pragma unroll
  for (int32_t i = 0; i < 8; i++) {
    int32_t single_val = val & 0xF;
    int32_t single_zero = zero & 0xF;

    single_val =
        single_val >= single_zero ? single_val - single_zero : 15 - single_val;

    new_val |= single_val << (i * 4);
    val >>= 4;
    zero >>= 4;
  }

  output[weight_idx] = new_val;
}
// Host entry point for the int4 -> fp8-activation Marlin weight preprocessing.
//
// Parameters:
//   qweight        - int32 CUDA tensor of packed 4-bit weights. Without zero
//                    points it is treated as a flat buffer whose element count
//                    must be a multiple of 32. With zero points it must be a
//                    contiguous 2-D (size_k, size_n / 8) tensor with
//                    size_k % 32 == 0.
//   qzeros_or_none - optional int32 CUDA tensor of packed zero points with
//                    shape (size_k / group_size, size_n / 8). Its presence
//                    selects the awq kernel.
//   inplace        - when true the result overwrites `qweight`; otherwise a
//                    new tensor created with empty_like(qweight) is returned.
//
// Returns the tensor holding the rewritten weights (`qweight` itself when
// `inplace` is true).
//
// Raises (via TORCH_CHECK) when a device, dtype, shape or layout precondition
// is violated, and (via C10_CUDA_KERNEL_LAUNCH_CHECK) when a kernel launch is
// rejected by the CUDA runtime.
torch::Tensor marlin_int4_fp8_preprocess(
    torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
    bool inplace) {
  TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
  TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
              "qweight.dtype != torch.int32");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
  // Launch on PyTorch's current stream (queried after the device guard is
  // active) so the kernel is ordered with the surrounding torch operations
  // instead of running on the legacy default stream.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);

  if (!qzeros_or_none.has_value()) {
    TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
                "qweight.numel() * 8 % 256 != 0");

    int blocks = qweight.numel() * 8 / 256;
    // A zero-sized grid is an invalid launch configuration; there is nothing
    // to do for an empty tensor.
    if (blocks > 0) {
      marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32, 0, stream>>>(
          static_cast<const int32_t*>(qweight.data_ptr()),
          static_cast<int32_t*>(output.data_ptr()));
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  } else {
    const torch::Tensor& qzeros = qzeros_or_none.value();

    TORCH_CHECK(qweight.dim() == 2, "qweight.dim() != 2");
    TORCH_CHECK(qzeros.dim() == 2, "qzeros.dim() != 2");

    int32_t size_k = qweight.size(0);
    int32_t size_n = qweight.size(1) * 8;

    TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
    TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
    TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
                "qzeros.dtype != torch.int32");
    TORCH_CHECK(device_of(qweight) == device_of(qzeros),
                "qzeros is not on the same device with qweight");

    // The awq kernel indexes both tensors as dense row-major matrices.
    TORCH_CHECK(qweight.is_contiguous(), "qweight is not contiguous");
    TORCH_CHECK(qzeros.is_contiguous(), "qzeros is not contiguous");

    TORCH_CHECK(qweight.size(1) == qzeros.size(1),
                "qweight.size(1) != qzeros.size(1)");
    // Validate the divisor before dividing: an empty qzeros would otherwise
    // trigger an integer division by zero on the host.
    TORCH_CHECK(qzeros.size(0) > 0, "qzeros.size(0) == 0");
    TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
                "qweight.size(0) % qzeros.size(0) != 0");

    int32_t group_size = qweight.size(0) / qzeros.size(0);
    TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");

    // Skip the launch for empty weights (a zero-sized grid is invalid).
    if (size_k > 0 && size_n > 0) {
      dim3 blocks(size_k / 32, size_n / 8);
      marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32, 0, stream>>>(
          static_cast<const int32_t*>(qweight.data_ptr()),
          static_cast<int32_t*>(output.data_ptr()),
          static_cast<const int32_t*>(qzeros.data_ptr()), size_n, size_k,
          group_size);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }

  return output;
}
// Registers marlin_int4_fp8_preprocess as the CUDA implementation of the op
// named "marlin_int4_fp8_preprocess" in the TORCH_EXTENSION_NAME library.
// NOTE(review): the op schema is presumably declared elsewhere (no m.def is
// visible in this file) — confirm.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}
......@@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
......@@ -77,65 +77,139 @@ __global__ void Marlin(
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragC& frag_c) {
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
template <typename scalar_t>
template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans(
const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
const typename ScalarType<scalar_t>::FragB& frag_b2,
typename ScalarType<scalar_t>::FragC& frag_c) {
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (k_size == 16) {
if constexpr (std::is_same<scalar_t, half>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, typename scalar_t>
__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(typename MarlinScalarType<type_id>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
......@@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s,
template <vllm::ScalarTypeId type_id>
__device__ inline void scale(typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragS& frag_s,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub(
typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s);
scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp);
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::scalar_t s,
typename MarlinScalarType<type_id>::scalar_t zp) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s2 = MarlinScalarType<type_id>::num2num2(s);
scalar_t2 zp2 = MarlinScalarType<type_id>::num2num2(zp);
frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
}
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::scalar_t2& frag_zp, int i) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 zp = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s_1,
typename ScalarType<scalar_t>::FragS& frag_s_2,
typename ScalarType<scalar_t>::FragS& frag_s_3,
typename ScalarType<scalar_t>::FragS& frag_s_4,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
template <vllm::ScalarTypeId type_id>
__device__ inline void scale4(
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragS& frag_s_1,
typename MarlinScalarType<type_id>::FragS& frag_s_2,
typename MarlinScalarType<type_id>::FragS& frag_s_3,
typename MarlinScalarType<type_id>::FragS& frag_s_4, int i) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
......@@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
}
// Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float* c,
typename ScalarType<scalar_t>::FragS& s) {
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_float(
float* c, typename MarlinScalarType<type_id>::FragS& s) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
c[0] = __fmul_rn(c[0], MarlinScalarType<type_id>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], MarlinScalarType<type_id>::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
......@@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) {
__syncthreads();
}
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -288,18 +371,23 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ A0, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C0, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
// only)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
// float scales of input matrix, only used when is_a_8bit == true.
// shape (m,)
const float* __restrict__ a_scales_ptr,
// fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ scales_ptr,
// fp16 global scale (for nvfp4 only)
const uint16_t* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
// int32 group indices of shape k
const int* __restrict__ g_idx,
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
......@@ -321,17 +409,35 @@ __global__ void Marlin(
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
using FragA = typename ScalarType<scalar_t>::FragA;
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890
// FP8 computation is only supported for Ada Lovelace or newer architectures.
if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
#endif
using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
const int4* A = A0;
int4* C = C0;
using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
using scalar_32bit_t = typename MarlinScalarType<a_type_id>::scalar_32bit_t;
using c_scalar_t = typename MarlinScalarType<c_type_id>::scalar_t;
using c_scalar_t2 = typename MarlinScalarType<c_type_id>::scalar_t2;
using FragA = typename MarlinScalarType<a_type_id>::FragA;
using FragB = typename MarlinScalarType<a_type_id>::FragB;
using FragC = typename MarlinScalarType<a_type_id>::FragC;
using FragS = typename MarlinScalarType<c_type_id>::FragS;
using FragZP = typename MarlinScalarType<c_type_id>::FragZP;
static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id);
static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id);
static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
......@@ -340,27 +446,35 @@ __global__ void Marlin(
static_assert(s_type == vllm::kFloat16);
}
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
constexpr bool is_a_8bit = a_type.size_bits() == 8;
if constexpr (!is_a_8bit) {
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8;
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
w_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
is_a_8bit || b_type == vllm::kFE4M3fn ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
has_zp && !is_zp_float && !(b_type == vllm::kU8);
c_scalar_t2 global_scale;
scalar_t2 global_scale;
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
// NVFP4 format requires global scale
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[0];
global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
}
constexpr bool has_act_order = group_blocks == 0;
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
constexpr int pack_factor = 32 / w_type.size_bits();
extern __shared__ int4 sh[];
float* sh_a_s = reinterpret_cast<float*>(sh);
int4* sh_new = sh + (is_a_8bit ? (4 * thread_m_blocks) : 0);
constexpr int pack_factor = 32 / b_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
......@@ -373,7 +487,19 @@ __global__ void Marlin(
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
int global_mn_tiles = parallel * n_tiles;
int part2_mn_tiles = global_mn_tiles;
int part1_mn_iters = 0;
bool in_part2 = false;
if (global_mn_tiles > gridDim.x) {
part2_mn_tiles = global_mn_tiles % gridDim.x;
if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x;
part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x;
}
int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
......@@ -385,28 +511,21 @@ __global__ void Marlin(
}
}
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
int slice_row = 0;
int slice_col_par = blockIdx.x;
int slice_col;
int slice_iters =
k_tiles; // number of threadblock tiles in the current slice
// total number of active threadblocks in the current slice
int slice_count = 1;
// index of threadblock in current slice; numbered bottom to top
int slice_idx = 0;
int par_id = 0;
int locks_off = 0;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
if (parallel * n_tiles >= gridDim.x) {
// when parallel * n_tiles >= sms
if (part2_mn_tiles >= gridDim.x) {
// when part2_mn_tiles >= sms
// then there are at most $sms$ conflict tile blocks
locks_off = blockIdx.x;
} else {
......@@ -415,10 +534,11 @@ __global__ void Marlin(
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&](bool first_init = false) {
bool first_init = true;
auto init_part2_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
......@@ -436,7 +556,7 @@ __global__ void Marlin(
if (col_off > 0) slice_idx--;
}
}
if (parallel * n_tiles >= gridDim.x) {
if (part2_mn_tiles >= gridDim.x) {
if (slice_count > 1 && slice_idx == slice_count - 1) {
locks_off++;
}
......@@ -466,28 +586,68 @@ __global__ void Marlin(
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * lda / 8;
A += 16 * thread_m_blocks * lda / (is_a_8bit ? 16 : 8);
C += 16 * thread_m_blocks * prob_n / 8;
slice_col = 0;
par_id++;
}
if (is_a_8bit && (first_init || slice_col == 0)) {
__syncthreads();
int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x;
cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd],
threadIdx.x < prob_m);
}
};
init_slice(true);
auto init_part1_slice = [&]() {
if (part1_mn_iters) {
part1_mn_iters--;
par_id = slice_col_par / n_tiles;
slice_col = slice_col_par % n_tiles;
slice_iters = k_tiles;
A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda;
C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n;
if (is_a_8bit) {
__syncthreads();
int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x;
cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd],
threadIdx.x < prob_m);
}
}
};
auto init_slice = [&]() {
if (!in_part2 && !part1_mn_iters) {
in_part2 = true;
slice_col_par = (iters * blockIdx.x) / k_tiles;
slice_row = (iters * blockIdx.x) % k_tiles;
slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles;
par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles;
A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda;
C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n;
}
if (!in_part2) {
init_part1_slice();
} else {
init_part2_slice();
first_init = false;
}
};
init_slice();
// A sizes/strides
// stride of the A matrix in global memory
int a_gl_stride = lda / 8;
int a_gl_stride = lda / (is_a_8bit ? 16 : 8);
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
......@@ -496,24 +656,25 @@ __global__ void Marlin(
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4));
constexpr int b_sh_stride =
((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4);
constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_stage =
b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
......@@ -527,7 +688,7 @@ __global__ void Marlin(
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4);
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
......@@ -550,17 +711,22 @@ __global__ void Marlin(
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +
(threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters;
int b_gl_rd;
if (threads <= b_sh_stride) {
b_gl_rd = threadIdx.x;
} else {
b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
}
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1));
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start;
......@@ -571,58 +737,54 @@ __global__ void Marlin(
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
(w_type == vllm::kFE2M1f ? 2 : 1) +
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
}
}
auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
bool s_sh_wr_pred = threadIdx.x < s_sh_stage;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
} else if constexpr (group_blocks >= thread_k_blocks) {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
}
}
auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
if constexpr (is_a_8bit) {
s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4);
} else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4;
int bias_sh_rd;
if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 8;
bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
} else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
bias_sh_rd = (is_a_8bit ? 4 : 8) * ((threadIdx.x / 32) % tb_n_warps) +
(threadIdx.x % 32) % 4;
}
......@@ -638,12 +800,16 @@ __global__ void Marlin(
if constexpr (has_zp) {
if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
zp_sh_rd =
8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
}
} else if (is_a_8bit) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % tb_n_warps / 2) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
} else {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
((threadIdx.x / 32) % tb_n_warps) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
}
......@@ -678,26 +844,19 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh;
int4* sh_red = sh;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max =
......@@ -708,8 +867,8 @@ __global__ void Marlin(
? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh + sh_size_b_red_min;
int4* sh_g_idx = sh + sh_b_red_bias_size;
int4* sh_bias = sh_new + sh_size_b_red_min;
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
......@@ -723,7 +882,8 @@ __global__ void Marlin(
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order
......@@ -731,6 +891,24 @@ __global__ void Marlin(
FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
if constexpr (is_a_8bit) {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
}
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
......@@ -788,15 +966,17 @@ __global__ void Marlin(
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) {
constexpr int count = div_ceil(b_sh_stride, threads);
int b_gl_idx =
b_gl_rd + (i % count) * threads +
b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride);
B_ptr[i] += b_gl_rd_delta_o;
cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]);
}
b_gl_rd += b_gl_rd_delta_o;
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
......@@ -816,44 +996,24 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
// Only fetch scales if this tile starts a new group
if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta * s_tb_groups;
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
// Only fetch zero points if this tile starts a new group
if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta * zp_tb_groups;
}
}
}
......@@ -891,14 +1051,14 @@ __global__ void Marlin(
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
ldsm<m_block_size_8 ? 2 : 4, a_type_id>(
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
&sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
......@@ -922,53 +1082,54 @@ __global__ void Marlin(
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages;
using IT1 = typename std::conditional_t<is_a_8bit, int2, int4>;
using IT0 = typename std::conditional_t<is_a_8bit, int, int2>;
constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1);
if constexpr (!has_act_order) {
// No act-order case
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
if (k == 0 && full_pipe == 0 && dequant_skip_flop) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
constexpr int g = group_blocks / thread_k_blocks;
if (pipe % g == 0) {
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
}
} else {
} else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / tb_n_warps;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = k_blocks / group_blocks2;
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
reinterpret_cast<int2*>(&frag_s[1])[0] =
reinterpret_cast<int2*>(&frag_s[0])[0];
}
}
}
......@@ -989,18 +1150,15 @@ __global__ void Marlin(
cur_k = 0;
// Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters);
cur_k += k % b_sh_wr_iters;
// Determine "position" inside the thread-block (based on warp and
// thread-id)
auto warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
int warp_row = warp_id / tb_n_warps;
int warp_col = warp_id % tb_n_warps;
cur_k += warp_row * 16;
cur_k += warp_row * 16 * b_sh_wr_iters;
auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
......@@ -1055,18 +1213,16 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
if (k == 0 && full_pipe == 0 || is_a_8bit) {
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
}
} else if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
constexpr int g = group_blocks / thread_k_blocks;
if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
......@@ -1075,21 +1231,11 @@ __global__ void Marlin(
}
} else {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int warp_row = warp_id / tb_n_warps;
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 2 : 1);
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
......@@ -1108,29 +1254,18 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp +
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
constexpr int g = group_blocks / thread_k_blocks;
if (pipe % g == 0 && k % b_sh_wr_iters == 0) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
} else {
} else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int warp_row = warp_id / tb_n_warps;
int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
......@@ -1141,33 +1276,46 @@ __global__ void Marlin(
}
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) {
if constexpr (a_type.size_bits() != b_type.size_bits()) {
if constexpr (is_a_8bit && has_zp) {
sub_zp_and_dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(
q, frag_b_ptr, zp);
} else {
dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(q, frag_b_ptr);
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
auto matmul = [&](int k, int pipe) {
if (is_a_8bit) return;
int k2 = k % 2;
constexpr int g =
group_blocks > 0 ? div_ceil(group_blocks, thread_k_blocks) : 1;
const bool is_new_zp =
((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) ||
(group_blocks == 0) ||
((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) &&
(pipe % g == 0) ||
(group_blocks == -1 && is_first_matmul_in_slice);
if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
if constexpr (b_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
static_assert(b_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = frag_qzp[k2][1];
}
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
dequant_data(zp_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_zp));
dequant_data(zp_quant_1,
reinterpret_cast<scalar_32bit_t*>(&frag_zp) + 2);
}
}
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
......@@ -1177,14 +1325,14 @@ __global__ void Marlin(
}
}
if constexpr (w_type == vllm::kFE2M1f) {
if constexpr (b_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]) + 2);
}
// We have the m dimension as the inner loop in order to encourage overlapping
......@@ -1195,61 +1343,168 @@ __global__ void Marlin(
FragB frag_b1;
int b_quant_0, b_quant_1;
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
if constexpr (b_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) {
} else if constexpr (b_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
static_assert(b_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
dequant_data(b_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_32bit_t*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) {
sub_zp<a_type_id>(frag_b0, frag_zp[j], 0);
sub_zp<a_type_id>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
if constexpr (has_act_order && !is_a_8bit) {
static_assert(group_blocks != -1);
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
scale4<a_type_id>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<a_type_id>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) {
group_blocks == -1 && !is_a_8bit) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
scalar_t2 s2 = Adtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
scale_and_sub<a_type_id>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<a_type_id>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 &&
!is_a_8bit) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
scale_and_sub<a_type_id>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<a_type_id>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1 && !is_a_8bit) {
scale<a_type_id>(frag_b0, frag_s[k2][j], 0);
scale<a_type_id>(frag_b1, frag_s[k2][j], 1);
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);
mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else {
mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
}
}
}
};
auto matmul_a8 = [&](int k) {
int k2 = k % 2;
#pragma unroll
for (int j = 0; j < 2; j++) {
FragB frag_b[2];
if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) {
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b));
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2);
} else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) {
int off = (threadIdx.x / 32) % 2 * 2 + j;
int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b), zp);
zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2, zp);
} else {
reinterpret_cast<int2*>(&frag_b)[0] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[0];
reinterpret_cast<int2*>(&frag_b)[1] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[1];
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
if constexpr (group_blocks != -1) {
if (group_blocks == 2 || k == 1) {
if constexpr (a_type == vllm::kS8) {
int2 s_vals[2];
s_vals[0] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[1]};
s_vals[1] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[1]};
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[0])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][0][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][0][g]) *
scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[1])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][1][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][1][g]) *
scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
} else {
float2 s_vals[2];
if constexpr (s_type_id != vllm::kFE8M0fnu.id()) {
static_assert(a_type.size_bits() == 16 ||
s_type.size_bits() == 16);
s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]);
s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]);
} else {
int32_t* s_vals_int = reinterpret_cast<int32_t*>(&s_vals[0]);
int32_t s_vals_e8m0 =
*reinterpret_cast<int32_t*>(&frag_s[k2][j][0]);
s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23;
s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15;
s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7;
s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1;
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[0])[g % 2];
frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[1])[g % 2];
frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
}
}
}
......@@ -1263,7 +1518,8 @@ __global__ void Marlin(
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_stride =
b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
......@@ -1278,7 +1534,8 @@ __global__ void Marlin(
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {
for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2;
j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
......@@ -1287,24 +1544,26 @@ __global__ void Marlin(
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
reinterpret_cast<FragC*>(
frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh_red[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
sh_red[red_sh_wr] = reinterpret_cast<int4*>(
&frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {
for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2;
i += (m_block_size_8 ? 2 : 1)) {
float* c_rd =
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
reinterpret_cast<FragC*>(
frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j];
}
}
__syncthreads();
......@@ -1320,10 +1579,10 @@ __global__ void Marlin(
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
constexpr int active_threads = 32 * tb_n_warps;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1);
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr;
if constexpr (m_block_size_8) {
......@@ -1331,9 +1590,9 @@ __global__ void Marlin(
4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
} else {
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) +
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1);
}
constexpr int c_sh_wr_delta = active_threads;
auto c_sh_wr = threadIdx.x;
......@@ -1351,6 +1610,14 @@ __global__ void Marlin(
&C[c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i],
(threadIdx.x % 4) * 2 + i < prob_m);
} else if constexpr (is_a_8bit) {
int2* sh_red_int2 = reinterpret_cast<int2*>(sh_red);
int2* c_int2 = reinterpret_cast<int2*>(C);
cp_async2_ca_pred(
&sh_red_int2[c_sh_wr + c_sh_wr_delta * i],
&c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
} else {
cp_async4_pred(
&sh_red[c_sh_wr + c_sh_wr_delta * i],
......@@ -1370,36 +1637,51 @@ __global__ void Marlin(
(m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m);
if (mask) {
if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
c_scalar_t* c_red_f16;
if constexpr (is_a_8bit) {
int2 tmp =
reinterpret_cast<int2*>(sh_red)[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
} else {
int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
}
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] +=
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
&frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j +
(i % 4) + delta] += Cdtype::num2float(c_red_f16[j]);
}
}
if (!last) {
int4 c;
c_scalar_t c_f16[is_a_8bit ? 4 : 8];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<scalar_t*>(&c)[j] =
Dtype::float2num(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]);
c_f16[j] = Cdtype::float2num(reinterpret_cast<float*>(
&frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j +
(i % 4) + delta]);
}
if constexpr (m_block_size_8)
if constexpr (m_block_size_8) {
C[c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c;
else
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i] =
*reinterpret_cast<int4*>(c_f16);
} else if constexpr (is_a_8bit) {
int2* c_int2 = reinterpret_cast<int2*>(C);
c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)] =
*reinterpret_cast<int2*>(c_f16);
} else {
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)] = c;
c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast<int4*>(c_f16);
}
}
}
}
......@@ -1414,10 +1696,10 @@ __global__ void Marlin(
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
constexpr int active_threads = 32 * tb_n_warps;
bool is_th_active = threadIdx.x < active_threads;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = locks_off * c_size;
......@@ -1471,7 +1753,7 @@ __global__ void Marlin(
} else {
c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32);
}
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
......@@ -1481,47 +1763,47 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
c_scalar_t2 res =
Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4 &&
if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit &&
b_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) {
scalar_t2 tmp_scale = s[0];
c_scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2(
tmp_scale = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
}
res = __hmul2(res, tmp_scale);
}
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale);
}
if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0];
c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2(
tmp_bias = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
}
res = __hadd2(res, tmp_bias);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
((c_scalar_t*)sh_red)[idx] = res.x;
((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
} else {
((scalar_t2*)sh_red)[idx] = res;
((c_scalar_t2*)sh_red)[idx] = res;
}
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
for (int j = 0; j < (is_a_8bit ? 2 : 4); j++) {
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
......@@ -1557,9 +1839,9 @@ __global__ void Marlin(
i++) {
if (c_gl_wr < c_gl_wr_end) {
if (use_atomic_add && slice_count > 1) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
c_scalar_t2* C_half2 = reinterpret_cast<c_scalar_t2*>(&C[c_gl_wr]);
c_scalar_t2* sh_red_half2 =
reinterpret_cast<c_scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll
for (int a = 0; a < 4; a++) {
atomicAdd(&C_half2[a], sh_red_half2[a]);
......@@ -1635,7 +1917,13 @@ __global__ void Marlin(
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
if constexpr (!is_a_8bit) {
matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0));
} else {
static_assert(group_blocks != 0 && group_blocks != 1);
matmul_a8(k);
}
}
slice_iters--;
if (slice_iters == 0) {
......@@ -1668,13 +1956,47 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worsen performance after compilation.
if (slice_iters == 0) {
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];
for (int i = 0; i < 2 * thread_m_blocks; i++)
frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4];
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][0][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][0][g] = c_val * s_val;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][1][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][1][g] = c_val * s_val;
}
}
}
}
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
......@@ -1692,20 +2014,27 @@ __global__ void Marlin(
}
if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
(has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) {
if constexpr (is_a_8bit) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
}
} else if (b_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);
c_scalar_t2* frag_s_half2 =
reinterpret_cast<c_scalar_t2*>(frag_s);
#pragma unroll
for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2(
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);
frag_s_half2[i] = Cdtype::num2num2(
reinterpret_cast<c_scalar_t*>(&frag_s_half2[i])[idx]);
}
}
}
......@@ -1715,26 +2044,48 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) {
#pragma unroll
for (int j = 0; j < 2; j++) {
float2 aa[2];
aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]);
aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[0])[g % 2];
frag_c[i][j][0][g] *= scale;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[1])[g % 2];
frag_c[i][j][1][g] *= scale;
}
}
}
} else if (!has_act_order && group_blocks == -1 &&
b_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(
scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(
scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
if constexpr (!m_block_size_8) {
scale_float<scalar_t>(
scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>(
scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
......@@ -1758,7 +2109,8 @@ __global__ void Marlin(
cp_async_wait<0>();
__syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
if constexpr (!is_a_8bit)
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads();
}
......@@ -1768,21 +2120,22 @@ __global__ void Marlin(
// only the last block in a slice actually writes the result
write_result(last);
slice_row = 0;
slice_col_par++;
slice_col++;
if (!in_part2) {
slice_col_par += gridDim.x;
} else {
slice_col_par++;
slice_col++;
}
is_first_matmul_in_slice = true;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
a_gl_rd += a_gl_rd_delta_o * slice_row;
b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) +
(threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row;
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading
......@@ -1791,12 +2144,28 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
}
}
start_pipes();
}
}
......
......@@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,"
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
......@@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin repack from GPTQ.
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
"SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
"SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// preprocess W-int4A-fp8 weight for marlin kernel
ops.def(
"marlin_int4_fp8_preprocess(Tensor qweight, "
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
......
......@@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod]
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
......
......@@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
......@@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
# Quantization configurations swept by marlin_moe_generate_valid_test_cases().
# Each entry is a dict with the following keys:
#   "b_type"            - required; the quantized weight scalar type.
#   "group_blocks"      - required; list of group sizes expressed in blocks.
#                         A positive value is multiplied by 16 to obtain the
#                         group size; non-positive values (-1 here) are used
#                         as the group size unchanged.
#   "a_type"            - optional list of activation scalar types; the
#                         generator falls back to [float16] when absent.
#   "c_type"            - optional list of scalar types from which the test
#                         derives its torch dtype; same [float16] fallback.
#   "support_act_order" - optional flag; act_order=True cases are generated
#                         only for entries that set it (default False).
MOE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
FUSED_MOE_MNK_FACTORS = [
(1, 128, 128),
(1, 2048, 128),
......@@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [4, 12]
e_list = [5, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
dtype_list = [torch.bfloat16]
group_size_list = [-1, 32, 128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False]
all_combinations = itertools.product(
MOE_MARLIN_QUANT_TEST_CONFIGS,
m_list,
n_list,
k_list,
e_list,
topk_list,
ep_size_list,
dtype_list,
group_size_list,
act_order_list,
quant_type_list,
is_k_full_list,
)
def is_invalid(
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
a_type,
b_type,
c_type,
group_blocks,
m,
n,
k,
e,
topk,
ep_size,
act_order,
is_k_full,
):
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
return False
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and k % group_size != 0:
return False
# Filter act_order
if act_order:
if group_size in (-1, k, n):
return False
if quant_type not in [scalar_types.uint4b8]:
return False
elif not is_k_full:
if act_order and group_size in [-1, k, n]:
return False
if group_size in [k, n]:
return False
if not act_order and is_k_full:
return False
return True
return a_type.size_bits < 16 or a_type is c_type
cases = []
for case in all_combinations:
if is_invalid(*case):
cases.append(case)
quant_test_config, m, n, k, _, _, _, act_order, *_ = case
if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (m, n, k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases
......@@ -571,6 +640,7 @@ class MarlinMoEWeightData:
qweight: torch.Tensor
scales: torch.Tensor
global_scale: torch.Tensor | None
a_scales_factor: torch.Tensor | None
g_idx: torch.Tensor | None
zeros: torch.Tensor | None
sort_indices: torch.Tensor | None
......@@ -583,11 +653,20 @@ class MarlinMoEWeightData:
group_size: int,
act_order: bool | None = None,
bias: torch.Tensor | None = None,
input_type: ScalarType = None,
) -> "MarlinMoEWeightData":
assert w.ndim == 3
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
k = w.shape[-1]
if input_type == scalar_types.int8:
input_dtype = torch.int8
elif input_type == scalar_types.float8_e4m3fn:
input_dtype = torch.float8_e4m3fn
else:
input_dtype = w.dtype
w_ref_l: list[torch.Tensor] = []
qweight_l: list[torch.Tensor] = []
scales_l: list[torch.Tensor] = []
......@@ -601,11 +680,13 @@ class MarlinMoEWeightData:
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref, qweight, scales, global_scale = (
rand_marlin_weight_nvfp4_like(w[i], group_size)
rand_marlin_weight_nvfp4_like(
w[i], group_size, input_dtype=input_dtype
)
)
else:
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
w[i], group_size
w[i], group_size, input_dtype=input_dtype
)
global_scale = None
......@@ -615,13 +696,18 @@ class MarlinMoEWeightData:
if global_scale is not None:
global_scale_l.append(global_scale)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
w_ref, qweight, scales = marlin_quant_fp8_torch(
w[i], group_size, input_dtype=input_dtype
)
w_ref_l.append(w_ref.T)
qweight_l.append(qweight)
scales_l.append(scales)
elif has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size
w[i].transpose(1, 0),
quant_type,
group_size,
input_dtype=input_dtype,
)
w_ref_l.append(w_ref.T)
......@@ -631,7 +717,12 @@ class MarlinMoEWeightData:
else:
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
w[i].transpose(1, 0),
quant_type,
group_size,
act_order,
test_perm,
input_dtype=input_dtype,
)
w_ref_l.append(w_ref.T)
......@@ -652,11 +743,18 @@ class MarlinMoEWeightData:
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
marlin_bias = stack_and_dev(bias_l) if bias_l else None
a_scales_factor = None
if input_type == scalar_types.int8 and group_size != -1:
a_scales_factor = 1 / 4096 * scales.max().float()
scales = scales / scales.max() * 4096
scales = scales.round().to(torch.int16).view(w.dtype)
return MarlinMoEWeightData(
w_ref=w_ref,
qweight=qweight,
scales=scales,
global_scale=global_scale,
a_scales_factor=a_scales_factor,
g_idx=g_idx,
zeros=zeros,
sort_indices=sort_indices,
......@@ -666,28 +764,47 @@ class MarlinMoEWeightData:
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
(
"a_type, b_type, c_type, group_blocks,"
"m, n, k, e, topk, ep_size, act_order, is_k_full"
),
marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
group_size: int,
act_order: bool,
quant_type: ScalarType,
is_k_full: bool,
a_type,
b_type,
c_type,
group_blocks,
m,
n,
k,
e,
topk,
ep_size,
act_order,
is_k_full,
):
torch.cuda.manual_seed(0)
torch.cuda.manual_seed(1)
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if c_type == scalar_types.float16:
dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
if a_type == scalar_types.int8:
a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1:
local_e = e // ep_size
......@@ -700,11 +817,19 @@ def test_fused_marlin_moe(
e_map = None
w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
w=w1,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
)
w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
w=w2,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
......@@ -712,8 +837,18 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
torch_output = torch_experts(
a,
w1_data.w_ref,
w2_data.w_ref,
topk_weight=topk_weight,
topk_ids=topk_ids,
global_num_experts=e,
expert_map=e_map,
quant_dtype=a_dtype,
per_act_token_quant=True,
)
marlin_output = fused_marlin_moe(
......@@ -733,15 +868,18 @@ def test_fused_marlin_moe(
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
input_global_scale1=w1_data.a_scales_factor,
input_global_scale2=w2_data.a_scales_factor,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
input_dtype=a_dtype,
quant_type_id=b_type.id,
is_k_full=is_k_full,
)
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
@pytest.mark.flaky(reruns=2)
......
......@@ -5,6 +5,8 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""
import itertools
import pytest
import torch
......@@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
......@@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
......@@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights,
sort_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
......@@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]
HQQ_SUPPORTED_GROUP_SIZES = [64]
# (n_factor, k_factor) pairs used by the repack tests: size_n is computed as
# n_chunk * n_factor and size_k as k_chunk * k_factor.
MARLIN_REPACK_NK_FACTORS = [
(4, 8),
(7, 5),
(13, 11),
]
MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
......@@ -74,6 +84,64 @@ MNK_FACTORS = [
DTYPES = [torch.float16, torch.bfloat16]
# Quantization configurations swept by marlin_generate_valid_test_cases().
# Each entry is a dict with the following keys:
#   "b_type"            - required; the quantized weight scalar type.
#   "group_blocks"      - required; list of group sizes expressed in blocks.
#                         A positive value is multiplied by 16 to obtain the
#                         group size; non-positive values (-1 here) are used
#                         as the group size unchanged.
#   "a_type"            - optional list of activation scalar types; the
#                         generator falls back to [float16, bfloat16].
#   "c_type"            - optional list of scalar types from which the test
#                         derives its torch dtype; same fallback as "a_type".
#   "support_act_order" - optional flag; act_order=True cases are generated
#                         only for entries that set it (default False).
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
......@@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    """Compare ops.marlin_int4_fp8_preprocess (no zero points) with torch.

    The reference maps every 4-bit value v to ``v - 8`` when ``v >= 8`` and to
    ``15 - v`` otherwise, then packs the result the same way as the input.
    """

    def pack_nibbles(nibbles):
        # Two adjacent columns share one byte (even column in the high
        # nibble); four such bytes are then reinterpreted as one int32.
        merged = nibbles[:, ::2] * 16 + nibbles[:, 1::2]
        return merged.to(torch.int8).view(torch.int32)

    raw = torch.randint(0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda")

    actual = ops.marlin_int4_fp8_preprocess(pack_nibbles(raw))

    remapped = torch.where(raw >= 8, raw - 8, 15 - raw)
    expected = pack_nibbles(remapped)

    assert (actual == expected).all()
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    """Compare ops.marlin_int4_fp8_preprocess (with zero points) with torch.

    The reference subtracts the per-group zero point from every 4-bit weight;
    wherever that difference is negative it uses ``15 - weight`` instead.
    """
    group_size = 128

    def pack_nibbles(nibbles):
        # Two adjacent columns share one byte (even column in the high
        # nibble); four such bytes are then reinterpreted as one int32.
        merged = nibbles[:, ::2] * 16 + nibbles[:, 1::2]
        return merged.to(torch.int8).view(torch.int32)

    # Weights are drawn before zero points to keep the RNG draw order stable.
    weights = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    zero_points = torch.randint(
        0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
    )

    actual = ops.marlin_int4_fp8_preprocess(
        pack_nibbles(weights), pack_nibbles(zero_points)
    )

    # Expand each zero-point row so it lines up with the rows of its group.
    zp_per_row = zero_points.repeat_interleave(group_size, 0)
    shifted = weights - zp_per_row
    remapped = torch.where(shifted < 0, 15 - weights, shifted)
    expected = pack_nibbles(remapped)

    assert (actual == expected).all()
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
......@@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
m_factor, n_factor, k_factor = mnk_factors
n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
group_size = 128
# Filter act_order
if act_order:
......@@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
return
if group_size == size_k:
return
if is_a_8bit:
return
# Normalize group_size
if group_size == -1:
......@@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
)
opcheck(
torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq,
sort_indices,
size_k,
size_n,
quant_type.size_bits,
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
)
torch.cuda.synchronize()
......@@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
group_size = 128
# Create input
b_weight = rand_data((size_k, size_n))
......@@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
)
opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
size_k,
size_n,
quant_type.size_bits,
q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
)
torch.cuda.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
def marlin_generate_valid_test_cases():
    """Build the parameter tuples for the dense Marlin GEMM test.

    Every combination of DENSE_MARLIN_QUANT_TEST_CONFIGS with the shape,
    act_order, is_k_full, use_atomic_add and use_fp32_reduce option lists is
    expanded and filtered.

    Returns:
        A list of tuples ``(a_type, b_type, c_type, group_blocks, size_m,
        size_n, size_k, act_order, is_k_full, use_atomic_add,
        use_fp32_reduce)``, one per retained case.
    """
    # Compute capabilities (major * 10 + minor) for which FP8-activation
    # cases are generated.
    fp8_activation_capabilities = (89, 120)

    def device_capability_as_int():
        # torch.cuda.get_device_capability() returns a (major, minor) tuple;
        # fold it into a single int so it can be matched against 89 / 120.
        # The earlier code compared the object returned by
        # current_platform.get_device_capability() with those ints directly.
        # NOTE(review): unless that call returns a plain int, the comparison
        # never matched and all FP8-activation cases were dropped - confirm.
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor

    def is_valid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        """Return True when the combination should be kept as a test case."""
        if use_atomic_add:
            # atomic-add and fp32-reduce are never exercised together.
            if use_fp32_reduce:
                return False
            # bf16 output with atomic-add is only kept on major capability >= 9.
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        # group_blocks counts blocks of 16 elements; non-positive values are
        # passed through unchanged.
        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and size_k % group_size != 0:
            return False

        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        # Without act_order only the is_k_full=False variant is kept.
        if not act_order and is_k_full:
            return False

        # A 16-bit (or wider) activation type must be the same as c_type.
        return a_type.size_bits < 16 or a_type is c_type

    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    f16_types = [scalar_types.float16, scalar_types.bfloat16]

    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            # FP8 activations are only generated on the listed capabilities.
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and device_capability_as_int() not in fp8_activation_capabilities
            ):
                continue
            # case[4:] is (act_order, is_k_full, use_atomic_add, use_fp32_reduce).
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if is_valid(*args):
                cases.append(args)
    return cases
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
(
"a_type, b_type, c_type, group_blocks,"
"size_m, size_n, size_k, act_order, is_k_full,"
"use_atomic_add, use_fp32_reduce"
),
marlin_generate_valid_test_cases(),
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
a_type,
b_type,
c_type,
group_blocks,
size_m,
size_n,
size_k,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
if has_zp:
return
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if size_k % group_size != 0:
return
if c_type == scalar_types.float16:
dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
a_input = rand_data((size_m, size_k), dtype)
b_weight = rand_data((size_k, size_n), dtype)
if a_type == scalar_types.int8:
a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32] or act_order:
return
if group_size == 32 and dtype == torch.float16:
return
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
if b_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size
b_weight.T, group_size, input_dtype=a_dtype
)
else:
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size
b_weight.T, group_size, input_dtype=a_dtype
)
marlin_s2 = None
g_idx = None
sort_indices = None
marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]:
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
elif b_type == scalar_types.float8_e4m3fn:
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size, input_dtype=a_dtype
)
g_idx = None
sort_indices = None
marlin_zp = None
marlin_s2 = None
elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size
b_weight, b_type, group_size, input_dtype=a_dtype
)
g_idx = None
sort_indices = None
marlin_s2 = None
else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order
b_weight, b_type, group_size, act_order, input_dtype=a_dtype
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
use_atomic_add,
use_fp32_reduce,
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
if a_type == scalar_types.int8:
a_input, a_scales = per_token_quant_int8(a_input)
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input_ref = a_input_ref.to(dtype)
if group_size != -1:
a_scales = a_scales / 4096 * marlin_s.max()
a_scales = a_scales.float()
marlin_s = marlin_s / marlin_s.max() * 4096
marlin_s = marlin_s.round().to(torch.int16).view(dtype)
elif a_type == scalar_types.float8_e4m3fn:
a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input_ref = a_input_ref.to(dtype)
else:
assert a_type.size_bits == 16
a_input_ref = a_input
a_scales = None
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
output = ops.gptq_marlin_gemm(
a_input,
None,
output,
marlin_q_w,
None,
marlin_s,
a_scales,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type,
b_type,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
......@@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)
torch.cuda.synchronize()
output_ref = torch.matmul(a_input_ref, w_ref)
max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
......@@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
None,
marlin_s,
None,
None,
marlin_zp,
g_idx,
g_idx_sort_indices,
......@@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
None,
marlin_s,
None,
None,
marlin_zp,
g_idx,
sort_indices,
......@@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_bias,
marlin_s,
None,
None,
marlin_zp,
g_idx,
sort_indices,
......
......@@ -846,6 +846,13 @@ def torch_experts(
or (expert_map is not None and global_num_experts == expert_map.shape[0])
)
if quant_dtype in [torch.float16, torch.bfloat16]:
quant_dtype = None
quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
if quant_input_only:
assert a1_scale is None and a2_scale is None
assert per_act_token_quant
M, K = a.shape
topk = topk_ids.shape[1]
......@@ -863,6 +870,9 @@ def torch_experts(
a, a1_scale, quant_dtype, per_act_token_quant, block_shape
)
if quant_input_only:
a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1)
......@@ -882,6 +892,14 @@ def torch_experts(
out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
elif quant_input_only:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
tmp2, tmp2_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant
)
tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None:
# block quantized
assert (
......
......@@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None,
......@@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
dtype = a.dtype
if dtype not in [torch.half, torch.bfloat16]:
dtype = b_scales.dtype
return torch.empty((size_m, size_n), device=a.device, dtype=dtype)
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(
......@@ -1167,8 +1171,11 @@ def gptq_marlin_repack(
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
return torch.ops._C.gptq_marlin_repack(
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
)
if hasattr(torch.ops._C, "gptq_marlin_repack"):
......@@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
size_k: torch.SymInt,
size_n: torch.SymInt,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
pack_factor = 32 // num_bits
marlin_tile_size = 16
......@@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
# awq_marlin
def awq_marlin_repack(
    b_q_weight: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
    is_a_8bit: bool = False,
) -> torch.Tensor:
    """Repack an AWQ-format quantized weight via the ``_C.awq_marlin_repack`` op.

    All arguments are forwarded positionally to the custom op.

    Args:
        b_q_weight: Packed quantized weight tensor.
        size_k: K dimension passed to the op.
        size_n: N dimension passed to the op.
        num_bits: Bit width of the quantized weight elements.
        is_a_8bit: Forwarded to the op. Callers in this file set it when the
            activation dtype is one byte wide. Defaults to False.

    Returns:
        The tensor returned by the custom op.
    """
    return torch.ops._C.awq_marlin_repack(
        b_q_weight, size_k, size_n, num_bits, is_a_8bit
    )
if hasattr(torch.ops._C, "awq_marlin_repack"):
......@@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"):
size_k: torch.SymInt,
size_n: torch.SymInt,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
pack_factor = 32 // num_bits
marlin_tile_size = 16
......@@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack(
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
......@@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack(
)
for e in range(num_experts):
output[e] = torch.ops._C.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits
b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
)
return output
......@@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack(
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
......@@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack(
)
for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits
b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
)
return output
def marlin_int4_fp8_preprocess(
qweight: torch.Tensor,
qzeros_or_none: torch.Tensor | None = None,
inplace: bool = False,
):
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
def gptq_marlin_gemm(
a: torch.Tensor,
c: torch.Tensor | None,
b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None,
......@@ -1283,6 +1309,7 @@ def gptq_marlin_gemm(
b_q_weight,
b_bias,
b_scales,
a_scales,
global_scale,
b_zeros,
g_idx,
......@@ -1600,7 +1627,7 @@ def allspark_repack_weight(
if use asymmetric quantization, has_zp = True.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
rearranged weight, scale, and optionally zero_point.
"""
K = qweight.shape[0]
......@@ -1683,7 +1710,7 @@ def scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
"""
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
......@@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm(
b_qweight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None,
......@@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
thread_k: int = -1,
thread_n: int = -1,
blocks_per_sm: int = -1,
) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm(
input,
......@@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm(
b_qweight,
b_bias,
b_scales,
a_scales,
global_scale,
b_qzeros,
g_idx,
......@@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add,
use_fp32_reduce,
is_zp_float,
thread_k,
thread_n,
blocks_per_sm,
)
......@@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
input: torch.Tensor,
output: torch.Tensor | None,
b_qweight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None,
perm: torch.Tensor | None,
......@@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
use_atomic_add: bool,
use_fp32_reduce: bool,
is_zp_float: bool,
) -> torch.Tensor:
):
return torch.empty(
(size_m * top_k, size_n), dtype=input.dtype, device=input.device
)
......@@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
"""
output = torch.empty_like(input, dtype=torch.int8)
token_num = input.numel() // input.shape[-1]
......
......@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
......@@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
),
# Activation (input) dtype for the Marlin kernel: "int8" or "fp8".
# Defaults to None when the variable is unset.
"VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
"VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
),
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
......
......@@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_intermediate_size,
maybe_warn_marlin_atomic_add,
marlin_quant_input,
)
from vllm.scalar_type import ScalarType, scalar_types
......@@ -65,6 +65,8 @@ def _fused_marlin_moe(
activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None
] = default_activation_func,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
......@@ -77,6 +79,7 @@ def _fused_marlin_moe(
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True,
) -> torch.Tensor:
assert hidden_states.ndim == 2
......@@ -106,18 +109,22 @@ def _fused_marlin_moe(
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
a_scales1 = None
gate_up_input = hidden_states
if input_dtype == torch.int8:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
if input_global_scale1 is not None:
a_scales1 = a_scales1 * input_global_scale1
elif input_dtype == torch.float8_e4m3fn:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states,
gate_up_input,
intermediate_cache1,
w1,
bias1,
w1_scale,
a_scales1,
global_scale1,
w1_zeros,
g_idx1,
......@@ -136,7 +143,7 @@ def _fused_marlin_moe(
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
......@@ -151,12 +158,25 @@ def _fused_marlin_moe(
if expert_map is not None:
output.zero_()
a_scales2 = None
if input_dtype == torch.int8:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
if input_global_scale2 is not None:
a_scales2 = a_scales2 * input_global_scale2
elif input_dtype == torch.float8_e4m3fn:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
output = ops.moe_wna16_marlin_gemm(
intermediate_cache2,
output,
w2,
bias2,
w2_scale,
a_scales2,
global_scale2,
w2_zeros,
g_idx2,
......@@ -175,7 +195,7 @@ def _fused_marlin_moe(
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_atomic_add=False,
use_fp32_reduce=True,
is_zp_float=False,
)
......@@ -203,6 +223,8 @@ def fused_marlin_moe(
] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None,
......@@ -216,6 +238,7 @@ def fused_marlin_moe(
intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True,
output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False,
) -> torch.Tensor:
"""
......@@ -287,6 +310,9 @@ def fused_marlin_moe(
if M * topk / E / block_size_m < 0.9:
break
if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
......@@ -313,6 +339,8 @@ def fused_marlin_moe(
num_tokens_post_padded=num_tokens_post_padded,
activation=activation,
activation_func=activation_func,
input_global_scale1=input_global_scale1,
input_global_scale2=input_global_scale2,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
......@@ -325,6 +353,7 @@ def fused_marlin_moe(
intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2,
output=None,
input_dtype=input_dtype,
is_k_full=is_k_full,
).view(-1, topk, K)
......
......@@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
AWQMoEMethod,
AWQMarlinMoEMethod,
)
quant_args_marlin = AWQMarlinConfig(
......@@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = {
......
......@@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig):
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer
......@@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig):
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict
)
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
......
......@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_marlin_supports_layer,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_moe_permute_scales,
......@@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig):
# num_bits -> type
TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
def __init__(
......@@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig):
return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMarlinLinearMethod(self)
quant_method = AWQMarlinLinearMethod(self)
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
......@@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return AWQMoEMethod(self, layer.moe_config)
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
return None
@classmethod
......@@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQMarlinConfig) -> None:
    """Store the quantization config and set per-instance defaults.

    Args:
        quant_config: AWQ Marlin quantization config for this method.
    """
    self.quant_config = quant_config
    # Weight type is fixed to unsigned 4-bit for this linear method.
    self.quant_type = scalar_types.uint4
    # Activation dtype. None here; the config's get_quant_method assigns it
    # from get_marlin_input_dtype(prefix) after construction.
    self.input_dtype = None
def create_weights(
self,
......@@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter(
data=torch.empty(
......@@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "qweight", marlin_qweight)
......@@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
......@@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "qzeros", marlin_zp)
......@@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=self.input_dtype,
)
class AWQMoEMethod(FusedMoEMethodBase):
class AWQMarlinMoEMethod(FusedMoEMethodBase):
def __init__(
self,
quant_config: AWQMarlinConfig,
......@@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
raise ValueError("AWQMarlinMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4
self.input_dtype = None
self.use_marlin = True
def create_weights(
......@@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
extra_weight_attrs.update(
{
"is_transposed": True,
......@@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
......@@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(
layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
inplace=True,
)
ops.marlin_int4_fp8_preprocess(
layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
inplace=True,
)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
......@@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
......@@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
......@@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
......@@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points(
......@@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
......@@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
......@@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
......@@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase):
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
input_dtype=self.input_dtype,
)
......@@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return CompressedTensorsMoEMethod.get_moe_method(
self, layer, layer_name=prefix
)
return None
def _add_fused_moe_to_target_scheme_map(self):
......@@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
format: str | None = None,
layer_name: str | None = None,
) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
......@@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig):
symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder,
layer_name=layer_name,
)
act_quant_format = is_activation_quantization_format(format)
......@@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig):
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant, input_quant=input_quant, format=format
weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_name=layer_name,
)
# Raise error if device does not support the scheme
......
......@@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
......@@ -101,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4Nvfp4MoeMethod",
"CompressedTensorsW4A4Nvfp4MoEMethod",
"CompressedTensorsW4A8Int8MoEMethod",
]
......@@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
prefix: str,
layer_name: str,
) -> "CompressedTensorsMoEMethod":
# FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [
prefix + proj_name
layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
......@@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config)
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config)
return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
)
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support,
)
......@@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoeMethod."
" for CompressedTensorsW4A4Nvfp4MoEMethod."
)
elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.")
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.")
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
def create_weights(
self,
......@@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return
# w13
if (
......@@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
):
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet."
"EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
)
return flashinfer_trtllm_fp4_moe(
......@@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
......@@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoeMethod."
"CompressedTensorsW4A4Nvfp4MoEMethod."
)
assert self.moe_quant_config is not None
......@@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
......@@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
def create_weights(
self,
......@@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
......@@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
......@@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
......@@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
......@@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not (
......@@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
w13_scale = torch.nn.Parameter(
torch.ones(
num_experts,
......@@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device
is_a_8bit = (
self.marlin_input_dtype is not None
and self.marlin_input_dtype.itemsize == 1
)
if self.marlin_input_dtype == torch.float8_e4m3fn:
# NOTE: for non-zp quantization format only
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
# when running models with grouped act order,
# resort to g_idx values provided in checkpoint
......@@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2],
self.num_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1]
* (self.group_size if self.group_size != -1 else self.packed_factor),
size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size,
is_a_8bit=is_a_8bit,
)
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4)
......@@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
......@@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full,
)
......@@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.quant_config = quant_config
......@@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
layer_name: str | None = None,
):
super().__init__(moe)
self.has_bias = self.moe.has_bias
......
......@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.parameter import (
......@@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size: int | None = None,
symmetric: bool | None = True,
actorder: ActivationOrdering | None = None,
layer_name: str | None = None,
):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
self.layer_name = layer_name
if self.group_size == -1 and self.strategy != "channel":
raise ValueError(
......@@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
if isinstance(kernel_type, MarlinLinearKernel):
input_dtype = get_marlin_input_dtype(self.layer_name)
if input_dtype is not None:
mp_linear_kernel_config.act_type = input_dtype
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = input_size != input_size_per_partition
......
......@@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
......@@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
quant_method = Fp8LinearMethod(self)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
......@@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer)
moe_quant_method = Fp8MoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
......@@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
......@@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first)
prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.input_scale
return
......@@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias,
)
......@@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant, layer.moe_parallel_config
)
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
......@@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = w13_weight.data
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
......@@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
......
......@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
......@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix
)
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
moe_quant_method = get_moe_quant_method(
self, layer, prefix, GPTQMarlinMoEMethod
)
if moe_quant_method is None:
return None
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
quant_method = get_linear_quant_method(
self, layer, prefix, GPTQMarlinLinearMethod
)
if quant_method is None:
return None
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
......@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
self.input_dtype = None
self.quant_type = self.quant_config.quant_type
# Verify supported on platform.
verify_marlin_supported(
......@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
input_dtype = self.input_dtype
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
......@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
act_type=params_dtype if input_dtype is None else input_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act,
......@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True
def create_weights(
......@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.input_dtype = self.input_dtype
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or (
......@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter(
......@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
......@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
......@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
......@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
......@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
......@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
......@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
)
......@@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase):
bias,
scales,
None,
None,
zeros,
layer.g_idx,
layer.g_idx_sort_indices,
......
......@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
......@@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
......@@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)
if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
......@@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
)
return x
......@@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
is_a_8bit=is_a_8bit,
)
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x
if c.has_g_idx:
......@@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
),
)
else:
......@@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
......@@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=c.act_type,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment