"vllm/vscode:/vscode.git/clone" did not exist on "560ffc4c7c7c63f7a01877e50e59ece406c39ef5"
Unverified Commit 1656ad37 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin@redhat.com>
parent fa59fe41
#include "marlin.cuh"
#include "core/registration.h"
// for only non-zp format (like gptq)
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
// qweight: (size_k * size_n // 8,)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output) {
int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
}
output[blockIdx.x * 32 + threadIdx.x] = new_val;
}
// for awq format only (with zp and with awq weight layout)
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
// AWQ qweight: (size_k, size_n // 8)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output,
// AWQ zeros: (size_k // group_size, size_n // 8)
const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
int32_t group_size) {
int32_t val =
qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
int32_t zero =
qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
blockIdx.y];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
int32_t single_zero = zero & 0xF;
single_val =
single_val >= single_zero ? single_val - single_zero : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
zero >>= 4;
}
output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}
torch::Tensor marlin_int4_fp8_preprocess(
torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
bool inplace) {
TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);
if (!qzeros_or_none.has_value()) {
TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
"qweight.numel() * 8 % 256 != 0");
int blocks = qweight.numel() * 8 / 256;
marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
} else {
int32_t size_k = qweight.size(0);
int32_t size_n = qweight.size(1) * 8;
torch::Tensor qzeros = qzeros_or_none.value();
TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
TORCH_CHECK(device_of(qweight) == device_of(qzeros),
"qzeros is not on the same device with qweight");
int32_t group_size = qweight.size(0) / qzeros.size(0);
TORCH_CHECK(qweight.size(1) == qzeros.size(1),
"qweight.size(1) != qzeros.size(1)");
TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
"qweight.size(0) % qzeros.size(0) != 0");
TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");
dim3 blocks(size_k / 32, size_n / 8);
marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
(const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
}
return output;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}
...@@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME { ...@@ -38,7 +38,7 @@ namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <typename scalar_t, // compute dtype, half or nv_float16 template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // weight MarlinScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
...@@ -77,65 +77,139 @@ __global__ void Marlin( ...@@ -77,65 +77,139 @@ __global__ void Marlin(
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation. // output/accumulation.
template <typename scalar_t> template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag, __device__ inline void mma(
const typename ScalarType<scalar_t>::FragB& frag_b, const typename MarlinScalarType<type_id>::FragA& a_frag,
typename ScalarType<scalar_t>::FragC& frag_c) { const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c); using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
if constexpr (std::is_same<scalar_t, half>::value) { if constexpr (k_size == 16) {
asm volatile( if constexpr (std::is_same<scalar_t, half>::value) {
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " float* c = reinterpret_cast<float*>(&frag_c);
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" asm volatile(
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
asm volatile( "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" float* c = reinterpret_cast<float*>(&frag_c);
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) asm volatile(
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
} else { : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]),
"f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]),
"r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} else if (k_size == 32) {
if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} }
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id, int k_size = 16>
__device__ inline void mma_trans( __device__ inline void mma_trans(
const typename ScalarType<scalar_t>::FragA& a_frag, const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b, const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename ScalarType<scalar_t>::FragB& frag_b2, const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename ScalarType<scalar_t>::FragC& frag_c) { typename MarlinScalarType<type_id>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2); const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c); float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) { using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
asm volatile( if constexpr (k_size == 16) {
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " if constexpr (std::is_same<scalar_t, half>::value) {
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" float* c = reinterpret_cast<float*>(&frag_c);
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) asm volatile(
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
asm volatile( : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) float* c = reinterpret_cast<float*>(&frag_c);
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), asm volatile(
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
"r"(c[3]));
}
} else { } else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, int8_t>::value) {
int32_t* c = reinterpret_cast<int32_t*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
: "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]),
"r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
} }
} }
// Instruction for loading a full 16x16 matrix fragment of operand A from shared // Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout. // memory, directly in tensor core layout.
template <int count, typename scalar_t> template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, __device__ inline void ldsm(typename MarlinScalarType<type_id>::FragA& frag_a,
const void* smem_ptr) { const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
...@@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, ...@@ -159,47 +233,54 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void scale(typename MarlinScalarType<type_id>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s, typename MarlinScalarType<type_id>::FragS& frag_s,
int i) { int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t2 s = using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]); scalar_t2 s = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s); frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s); frag_b[1] = __hmul2(frag_b[1], s);
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub( __device__ inline void scale_and_sub(
typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) { typename MarlinScalarType<type_id>::FragB& frag_b,
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; typename MarlinScalarType<type_id>::scalar_t s,
scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s); typename MarlinScalarType<type_id>::scalar_t zp) {
scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp); using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s2 = MarlinScalarType<type_id>::num2num2(s);
scalar_t2 zp2 = MarlinScalarType<type_id>::num2num2(zp);
frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
} }
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void sub_zp(
typename ScalarType<scalar_t>::scalar_t2& frag_zp, typename MarlinScalarType<type_id>::FragB& frag_b,
int i) { typename MarlinScalarType<type_id>::scalar_t2& frag_zp, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t2 zp = using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]); scalar_t2 zp = MarlinScalarType<type_id>::num2num2(
reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp); frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp); frag_b[1] = __hsub2(frag_b[1], zp);
} }
// Same as above, but for act_order (each K is multiplied individually) // Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b, __device__ inline void scale4(
typename ScalarType<scalar_t>::FragS& frag_s_1, typename MarlinScalarType<type_id>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s_2, typename MarlinScalarType<type_id>::FragS& frag_s_1,
typename ScalarType<scalar_t>::FragS& frag_s_3, typename MarlinScalarType<type_id>::FragS& frag_s_2,
typename ScalarType<scalar_t>::FragS& frag_s_4, typename MarlinScalarType<type_id>::FragS& frag_s_3,
int i) { typename MarlinScalarType<type_id>::FragS& frag_s_4, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<type_id>::scalar_t2;
scalar_t2 s_val_1_2; scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i]; s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i]; s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
...@@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b, ...@@ -213,12 +294,13 @@ __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
} }
// Given 2 floats multiply by 2 scales (halves) // Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t> template <vllm::ScalarTypeId type_id>
__device__ inline void scale_float(float* c, __device__ inline void scale_float(
typename ScalarType<scalar_t>::FragS& s) { float* c, typename MarlinScalarType<type_id>::FragS& s) {
using scalar_t = typename MarlinScalarType<type_id>::scalar_t;
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s); scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0])); c[0] = __fmul_rn(c[0], MarlinScalarType<type_id>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1])); c[1] = __fmul_rn(c[1], MarlinScalarType<type_id>::num2float(s_ptr[1]));
} }
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
...@@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) { ...@@ -270,9 +352,10 @@ __device__ inline void wait_negative_and_add(int* lock) {
__syncthreads(); __syncthreads();
} }
template <typename scalar_t, // compute dtype, half or nv_float16 template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId w_type_id, // weight ScalarType id const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the // dimension (batchsize) of the
...@@ -288,18 +371,23 @@ template <typename scalar_t, // compute dtype, half or nv_float16 ...@@ -288,18 +371,23 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const bool is_zp_float // is zero point of float16 type? const bool is_zp_float // is zero point of float16 type?
> >
__global__ void Marlin( __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ A0, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn int4* __restrict__ C0, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ b_bias_ptr, const int4* __restrict__ b_bias_ptr,
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // float scales of input matrix, only used when is_a_8bit == true.
// (k/groupsize)xn // shape (m,)
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 const float* __restrict__ a_scales_ptr,
// only) // fp16 quantization scales. shape (k/groupsize, n)
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape const int4* __restrict__ scales_ptr,
// (k/groupsize)x(n/pack_factor) // fp16 global scale (for nvfp4// only)
const int* __restrict__ g_idx, // int32 group indices of shape k const uint16_t* __restrict__ global_scale_ptr,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const int4* __restrict__ zp_ptr,
// int32 group indices of shape k
const int* __restrict__ g_idx,
int num_groups, // number of scale groups per output channel int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m int prob_m, // batch dimension m
int prob_n, // output dimension n int prob_n, // output dimension n
...@@ -321,17 +409,35 @@ __global__ void Marlin( ...@@ -321,17 +409,35 @@ __global__ void Marlin(
// ensures good utilization of all SMs for many kinds of shape and GPU // ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock // configurations, while requiring as few slow global cross-threadblock
// reductions as possible. // reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890
using FragA = typename ScalarType<scalar_t>::FragA; // FP8 computation is only supported for Ada Lovelace or newer architectures.
using FragB = typename ScalarType<scalar_t>::FragB; if constexpr (a_type_id == vllm::kFE4M3fn.id()) return;
using FragC = typename ScalarType<scalar_t>::FragC; #endif
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP; using Adtype = MarlinScalarType<a_type_id>;
using Cdtype = MarlinScalarType<c_type_id>;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); const int4* A = A0;
int4* C = C0;
using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
using scalar_32bit_t = typename MarlinScalarType<a_type_id>::scalar_32bit_t;
using c_scalar_t = typename MarlinScalarType<c_type_id>::scalar_t;
using c_scalar_t2 = typename MarlinScalarType<c_type_id>::scalar_t2;
using FragA = typename MarlinScalarType<a_type_id>::FragA;
using FragB = typename MarlinScalarType<a_type_id>::FragB;
using FragC = typename MarlinScalarType<a_type_id>::FragC;
using FragS = typename MarlinScalarType<c_type_id>::FragS;
using FragZP = typename MarlinScalarType<c_type_id>::FragZP;
static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id);
static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id);
static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id);
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2); s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
...@@ -340,27 +446,35 @@ __global__ void Marlin( ...@@ -340,27 +446,35 @@ __global__ void Marlin(
static_assert(s_type == vllm::kFloat16); static_assert(s_type == vllm::kFloat16);
} }
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || if constexpr (!is_a_8bit) {
w_type == vllm::kU4B8 || w_type == vllm::kU8B128; static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8;
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details // see comments of dequant.h for more details
constexpr bool dequant_skip_flop = constexpr bool dequant_skip_flop =
w_type == vllm::kFE4M3fn || is_a_8bit || b_type == vllm::kFE4M3fn ||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value || has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(w_type == vllm::kU8); has_zp && !is_zp_float && !(b_type == vllm::kU8);
c_scalar_t2 global_scale;
scalar_t2 global_scale; if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { uint16_t val = global_scale_ptr[0];
// NVFP4 format requires global scale global_scale = Cdtype::num2num2(*reinterpret_cast<c_scalar_t*>(&val));
uint16_t val = scale2_ptr[0];
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
} }
constexpr bool has_act_order = group_blocks == 0; constexpr bool has_act_order = group_blocks == 0;
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
constexpr int pack_factor = 32 / w_type.size_bits(); extern __shared__ int4 sh[];
float* sh_a_s = reinterpret_cast<float*>(sh);
int4* sh_new = sh + (is_a_8bit ? (4 * thread_m_blocks) : 0);
constexpr int pack_factor = 32 / b_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8); static_assert(thread_m_blocks == 1 || !m_block_size_8);
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
...@@ -373,7 +487,19 @@ __global__ void Marlin( ...@@ -373,7 +487,19 @@ __global__ void Marlin(
int k_tiles = prob_k / 16 / thread_k_blocks; int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks; int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
int global_mn_tiles = parallel * n_tiles;
int part2_mn_tiles = global_mn_tiles;
int part1_mn_iters = 0;
bool in_part2 = false;
if (global_mn_tiles > gridDim.x) {
part2_mn_tiles = global_mn_tiles % gridDim.x;
if (part2_mn_tiles * 3 <= gridDim.x) part2_mn_tiles += gridDim.x;
part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x;
}
int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) { if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) { if (group_blocks >= thread_k_blocks) {
...@@ -385,28 +511,21 @@ __global__ void Marlin( ...@@ -385,28 +511,21 @@ __global__ void Marlin(
} }
} }
int slice_row = (iters * blockIdx.x) % k_tiles; int slice_row = 0;
int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col_par = blockIdx.x;
int slice_col = slice_col_par; int slice_col;
int slice_iters; // number of threadblock tiles in the current slice int slice_iters =
int slice_count = k_tiles; // number of threadblock tiles in the current slice
0; // total number of active threadblocks in the current slice // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to int slice_count = 1;
// top // index of threadblock in current slice; numbered bottom to top
int slice_idx = 0;
int par_id = 0; int par_id = 0;
int locks_off = 0; int locks_off = 0;
// We can easily implement parallel problem execution by just remapping if (part2_mn_tiles >= gridDim.x) {
// indices and advancing global pointers // when part2_mn_tiles >= sms
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
if (parallel * n_tiles >= gridDim.x) {
// when parallel * n_tiles >= sms
// then there are at most $sms$ conflict tile blocks // then there are at most $sms$ conflict tile blocks
locks_off = blockIdx.x; locks_off = blockIdx.x;
} else { } else {
...@@ -415,10 +534,11 @@ __global__ void Marlin( ...@@ -415,10 +534,11 @@ __global__ void Marlin(
// Compute all information about the current slice which is required for // Compute all information about the current slice which is required for
// synchronization. // synchronization.
auto init_slice = [&](bool first_init = false) { bool first_init = true;
auto init_part2_slice = [&]() {
slice_iters = slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) slice_iters = 0;
if (slice_iters == 0) return; if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1; slice_count = 1;
...@@ -436,7 +556,7 @@ __global__ void Marlin( ...@@ -436,7 +556,7 @@ __global__ void Marlin(
if (col_off > 0) slice_idx--; if (col_off > 0) slice_idx--;
} }
} }
if (parallel * n_tiles >= gridDim.x) { if (part2_mn_tiles >= gridDim.x) {
if (slice_count > 1 && slice_idx == slice_count - 1) { if (slice_count > 1 && slice_idx == slice_count - 1) {
locks_off++; locks_off++;
} }
...@@ -466,28 +586,68 @@ __global__ void Marlin( ...@@ -466,28 +586,68 @@ __global__ void Marlin(
} }
if (slice_col == n_tiles) { if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * lda / 8; A += 16 * thread_m_blocks * lda / (is_a_8bit ? 16 : 8);
C += 16 * thread_m_blocks * prob_n / 8; C += 16 * thread_m_blocks * prob_n / 8;
slice_col = 0; slice_col = 0;
par_id++; par_id++;
} }
if (is_a_8bit && (first_init || slice_col == 0)) {
__syncthreads();
int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x;
cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd],
threadIdx.x < prob_m);
}
}; };
init_slice(true);
auto init_part1_slice = [&]() {
if (part1_mn_iters) {
part1_mn_iters--;
par_id = slice_col_par / n_tiles;
slice_col = slice_col_par % n_tiles;
slice_iters = k_tiles;
A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda;
C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n;
if (is_a_8bit) {
__syncthreads();
int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x;
cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd],
threadIdx.x < prob_m);
}
}
};
auto init_slice = [&]() {
if (!in_part2 && !part1_mn_iters) {
in_part2 = true;
slice_col_par = (iters * blockIdx.x) / k_tiles;
slice_row = (iters * blockIdx.x) % k_tiles;
slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles;
par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles;
A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda;
C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n;
}
if (!in_part2) {
init_part1_slice();
} else {
init_part2_slice();
first_init = false;
}
};
init_slice();
// A sizes/strides // A sizes/strides
// stride of the A matrix in global memory // stride of the A matrix in global memory
int a_gl_stride = lda / 8; int a_gl_stride = lda / (is_a_8bit ? 16 : 8);
// stride of an A matrix tile in shared memory // stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8; constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// delta between subsequent A tiles in global memory // delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8);
// between subsequent accesses within a tile // between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes // between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile // within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16; constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile // overall size of a tile
...@@ -496,24 +656,25 @@ __global__ void Marlin( ...@@ -496,24 +656,25 @@ __global__ void Marlin(
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides // B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4); int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4));
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; constexpr int b_sh_stride =
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4);
constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs; constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs; constexpr int b_sh_stage =
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1);
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order // Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8; int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups = constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) ? thread_k_blocks / group_blocks
: 1; : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride; int s_gl_rd_delta = s_gl_stride;
...@@ -527,7 +688,7 @@ __global__ void Marlin( ...@@ -527,7 +688,7 @@ __global__ void Marlin(
int act_s_col_stride = 1; int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8; int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4; constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4);
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides // Zero-points sizes/strides
...@@ -550,17 +711,22 @@ __global__ void Marlin( ...@@ -550,17 +711,22 @@ __global__ void Marlin(
int a_sh_rd = int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +
(threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters;
int b_gl_rd;
if (threads <= b_sh_stride) {
b_gl_rd = threadIdx.x;
} else {
b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
}
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row; b_gl_rd += b_gl_rd_delta_o * slice_row;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs; auto b_sh_rd = threadIdx.x * b_thread_vecs;
b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1));
// For act_order // For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row; int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters; int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start; int slice_k_start_shared_fetch = slice_k_start;
...@@ -571,58 +737,54 @@ __global__ void Marlin( ...@@ -571,58 +737,54 @@ __global__ void Marlin(
if constexpr (!has_act_order) { if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else { } else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
(w_type == vllm::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x; s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
} }
} }
auto s_sh_wr = threadIdx.x; auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stage;
// Zero-points // Zero-points
int zp_gl_rd; int zp_gl_rd;
if constexpr (has_zp) { if constexpr (has_zp) {
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else { } else if constexpr (group_blocks >= thread_k_blocks) {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x; zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
} }
} }
auto zp_sh_wr = threadIdx.x; auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage;
// We use a different scale layout for grouped and column-wise quantization as // We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in // we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case. // row-major in the latter case.
int s_sh_rd; int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { if constexpr (is_a_8bit) {
auto warp_id = threadIdx.x / 32; s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4);
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
} else if constexpr (group_blocks != -1) } else if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
(threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && else if constexpr (group_blocks == -1 &&
(m_block_size_8 || (has_zp && !dequant_skip_flop))) (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
(threadIdx.x % 32) / 8;
else else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4;
(threadIdx.x % 32) % 4;
int bias_sh_rd; int bias_sh_rd;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8;
(threadIdx.x % 32) / 8;
} else { } else {
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + bias_sh_rd = (is_a_8bit ? 4 : 8) * ((threadIdx.x / 32) % tb_n_warps) +
(threadIdx.x % 32) % 4; (threadIdx.x % 32) % 4;
} }
...@@ -638,12 +800,16 @@ __global__ void Marlin( ...@@ -638,12 +800,16 @@ __global__ void Marlin(
if constexpr (has_zp) { if constexpr (has_zp) {
if constexpr (is_zp_float) { if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + zp_sh_rd =
(threadIdx.x % 32) / 4; 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4;
} }
} else if (is_a_8bit) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % tb_n_warps / 2) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
} else { } else {
zp_sh_rd = num_ints_per_thread * num_col_threads * zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) + ((threadIdx.x / 32) % tb_n_warps) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
} }
} }
...@@ -678,26 +844,19 @@ __global__ void Marlin( ...@@ -678,26 +844,19 @@ __global__ void Marlin(
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < thread_m_blocks; j++) for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] = a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd);
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
} }
// Since B-accesses have non-constant stride they have to be computed at // Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by // runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny // maintining multiple pointers (we have enough registers), a tiny
// optimization. // optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines. // Shared memory storage for global fetch pipelines.
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage; constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh; int4* sh_b = sh_new;
int4* sh_red = sh; int4* sh_red = sh_new;
constexpr int sh_size_b_red_min = constexpr int sh_size_b_red_min =
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size); (sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
constexpr int sh_size_b_red_max = constexpr int sh_size_b_red_max =
...@@ -708,8 +867,8 @@ __global__ void Marlin( ...@@ -708,8 +867,8 @@ __global__ void Marlin(
? sh_size_b_red_max ? sh_size_b_red_max
: (sh_size_b_red_min + sh_bias_size); : (sh_size_b_red_min + sh_bias_size);
int4* sh_bias = sh + sh_size_b_red_min; int4* sh_bias = sh_new + sh_size_b_red_min;
int4* sh_g_idx = sh + sh_b_red_bias_size; int4* sh_g_idx = sh_new + sh_b_red_bias_size;
int4* sh_zp = sh_g_idx + (stages * g_idx_stage); int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage); : (stages * s_sh_stage);
...@@ -723,7 +882,8 @@ __global__ void Marlin( ...@@ -723,7 +882,8 @@ __global__ void Marlin(
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS frag_bias[2][4]; FragS frag_bias[2][4];
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
...@@ -731,6 +891,24 @@ __global__ void Marlin( ...@@ -731,6 +891,24 @@ __global__ void Marlin(
FragZP frag_zp; // Zero-points in fp16 FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
if constexpr (is_a_8bit) {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
}
// Zero accumulators. // Zero accumulators.
auto zero_accums = [&]() { auto zero_accums = [&]() {
#pragma unroll #pragma unroll
...@@ -788,15 +966,17 @@ __global__ void Marlin( ...@@ -788,15 +966,17 @@ __global__ void Marlin(
} }
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) {
#pragma unroll constexpr int count = div_ceil(b_sh_stride, threads);
for (int j = 0; j < b_thread_vecs; j++) { int b_gl_idx =
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); b_gl_rd + (i % count) * threads +
} b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride);
B_ptr[i] += b_gl_rd_delta_o; cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]);
} }
b_gl_rd += b_gl_rd_delta_o;
if constexpr (has_act_order) { if constexpr (has_act_order) {
// Fetch g_idx thread-block portion // Fetch g_idx thread-block portion
int full_pipe = a_off; int full_pipe = a_off;
...@@ -816,44 +996,24 @@ __global__ void Marlin( ...@@ -816,44 +996,24 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group
// Only fetch scales if this tile starts a new group if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (pipe % (group_blocks / thread_k_blocks) == 0) { if (s_sh_wr_pred) {
if (s_sh_wr_pred) { cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
} }
s_gl_rd += s_gl_rd_delta * s_tb_groups;
} }
} }
if constexpr (has_zp && group_blocks != -1) { if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) { // Only fetch zero points if this tile starts a new group
// Only fetch zero-points if this tile starts a new group if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) {
if (pipe % (group_blocks / thread_k_blocks) == 0) { if (zp_sh_wr_pred) {
if (zp_sh_wr_pred) { cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
} }
zp_gl_rd += zp_gl_rd_delta * zp_tb_groups;
} }
} }
} }
...@@ -891,14 +1051,14 @@ __global__ void Marlin( ...@@ -891,14 +1051,14 @@ __global__ void Marlin(
int4* sh_a_stage = sh_a + a_sh_stage * pipe; int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>( ldsm<m_block_size_8 ? 2 : 4, a_type_id>(
frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe; int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_thread_vecs; i++) { for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>( frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]);
} }
}; };
...@@ -922,53 +1082,54 @@ __global__ void Marlin( ...@@ -922,53 +1082,54 @@ __global__ void Marlin(
auto fetch_scales_to_registers = [&](int k, int full_pipe) { auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages; int pipe = full_pipe % stages;
using IT1 = typename std::conditional_t<is_a_8bit, int2, int4>;
using IT0 = typename std::conditional_t<is_a_8bit, int, int2>;
constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1);
if constexpr (!has_act_order) { if constexpr (!has_act_order) {
// No act-order case // No act-order case
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
// load only when starting a new slice // load only when starting a new slice
if (k == 0 && full_pipe == 0) { if (k == 0 && full_pipe == 0 && dequant_skip_flop) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd]; reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
} }
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) { constexpr int g = group_blocks / thread_k_blocks;
int4* sh_s_stage = if (pipe % g == 0) {
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * if (k % b_sh_wr_iters == 0) {
(pipe / (group_blocks / thread_k_blocks))); int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else { } else {
reinterpret_cast<int4*>(&frag_s[1])[0] = reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0]; reinterpret_cast<int4*>(&frag_s[0])[0];
}
} }
} else { } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; int warp_row = warp_id / tb_n_warps;
int warp_row = warp_id / n_warps; int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = k_blocks / group_blocks2;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id =
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (w_type_id != vllm::kFE2M1f.id()) { if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { } else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else { } else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] = reinterpret_cast<int2*>(&frag_s[1])[0] =
reinterpret_cast<int2*>( reinterpret_cast<int2*>(&frag_s[0])[0];
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
k % 2];
} }
} }
} }
...@@ -989,18 +1150,15 @@ __global__ void Marlin( ...@@ -989,18 +1150,15 @@ __global__ void Marlin(
cur_k = 0; cur_k = 0;
// Progress to current iteration // Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters); cur_k += k % b_sh_wr_iters;
// Determine "position" inside the thread-block (based on warp and // Determine "position" inside the thread-block (based on warp and
// thread-id) // thread-id)
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = int warp_row = warp_id / tb_n_warps;
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_col = warp_id % tb_n_warps;
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
cur_k += warp_row * 16; cur_k += warp_row * 16 * b_sh_wr_iters;
auto th_id = threadIdx.x % 32; auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
...@@ -1055,18 +1213,16 @@ __global__ void Marlin( ...@@ -1055,18 +1213,16 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) { if constexpr (group_blocks == -1) {
// load only when starting a new slice // load only when starting a new slice
if (k == 0 && full_pipe == 0) { if (k == 0 && full_pipe == 0 || is_a_8bit) {
#pragma unroll #pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i]; frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
} }
} }
} else if constexpr (group_blocks >= thread_k_blocks) { } else if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) { constexpr int g = group_blocks / thread_k_blocks;
int4* sh_zp_stage = if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) {
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll #pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) { for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = frag_qzp[k % 2][i] =
...@@ -1075,21 +1231,11 @@ __global__ void Marlin( ...@@ -1075,21 +1231,11 @@ __global__ void Marlin(
} }
} else { } else {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16; int warp_row = warp_id / tb_n_warps;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16; int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_group_id = 0; int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 2 : 1);
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
...@@ -1108,29 +1254,18 @@ __global__ void Marlin( ...@@ -1108,29 +1254,18 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) { if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) { if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) { constexpr int g = group_blocks / thread_k_blocks;
int4* sh_zp_stage = if (pipe % g == 0 && k % b_sh_wr_iters == 0) {
sh_zp + int4* sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g));
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd]; sh_zp_stage[zp_sh_rd];
} }
} else { } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) {
auto warp_id = threadIdx.x / 32; auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps; int warp_row = warp_id / tb_n_warps;
int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks; int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
...@@ -1141,33 +1276,46 @@ __global__ void Marlin( ...@@ -1141,33 +1276,46 @@ __global__ void Marlin(
} }
}; };
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { auto dequant_data = [&](int q, scalar_32bit_t* frag_b_ptr, int zp = 0) {
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr); if constexpr (a_type.size_bits() != b_type.size_bits()) {
if constexpr (is_a_8bit && has_zp) {
sub_zp_and_dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(
q, frag_b_ptr, zp);
} else {
dequant<scalar_32bit_t, b_type_id, dequant_skip_flop>(q, frag_b_ptr);
}
}
}; };
// Execute the actual tensor core matmul of a sub-tile. // Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true; bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) { auto matmul = [&](int k, int pipe) {
if (is_a_8bit) return;
int k2 = k % 2; int k2 = k % 2;
constexpr int g =
group_blocks > 0 ? div_ceil(group_blocks, thread_k_blocks) : 1;
const bool is_new_zp = const bool is_new_zp =
((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || (group_blocks == 0) ||
((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) &&
(pipe % g == 0) ||
(group_blocks == -1 && is_first_matmul_in_slice); (group_blocks == -1 && is_first_matmul_in_slice);
if constexpr (has_zp && !is_zp_float) { if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) { if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
int zp_quant_0, zp_quant_1; int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) { if constexpr (b_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k2][0]; zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = zp_quant_0 >> 8; zp_quant_1 = zp_quant_0 >> 8;
} else { } else {
static_assert(w_type.size_bits() == 8); static_assert(b_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k2][0]; zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = frag_qzp[k2][1]; zp_quant_1 = frag_qzp[k2][1];
} }
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp)); dequant_data(zp_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2); dequant_data(zp_quant_1,
reinterpret_cast<scalar_32bit_t*>(&frag_zp) + 2);
} }
} }
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
...@@ -1177,14 +1325,14 @@ __global__ void Marlin( ...@@ -1177,14 +1325,14 @@ __global__ void Marlin(
} }
} }
if constexpr (w_type == vllm::kFE2M1f) { if constexpr (b_type == vllm::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0]; int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1]; int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
dequant_fp8_scales<scalar_t2, s_type_id>( dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2])); s_quant_0, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>( dequant_fp8_scales<c_scalar_t2, s_type_id>(
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2); s_quant_1, reinterpret_cast<c_scalar_t2*>(&frag_s[k2]) + 2);
} }
// We have the m dimension as the inner loop in order to encourage overlapping // We have the m dimension as the inner loop in order to encourage overlapping
...@@ -1195,61 +1343,168 @@ __global__ void Marlin( ...@@ -1195,61 +1343,168 @@ __global__ void Marlin(
FragB frag_b1; FragB frag_b1;
int b_quant_0, b_quant_1; int b_quant_0, b_quant_1;
if constexpr (w_type_id == vllm::kFE2M1f.id()) { if constexpr (b_type_id == vllm::kFE2M1f.id()) {
b_quant_1 = frag_b_quant[k2][0][j]; b_quant_1 = frag_b_quant[k2][0][j];
b_quant_0 = b_quant_1 << 8; b_quant_0 = b_quant_1 << 8;
} else if constexpr (w_type.size_bits() == 4) { } else if constexpr (b_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j]; b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8; b_quant_1 = b_quant_0 >> 8;
} else { } else {
static_assert(w_type.size_bits() == 8); static_assert(b_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]); int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);
b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
} }
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0)); dequant_data(b_quant_0, reinterpret_cast<scalar_32bit_t*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1)); dequant_data(b_quant_1, reinterpret_cast<scalar_32bit_t*>(&frag_b1));
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0); sub_zp<a_type_id>(frag_b0, frag_zp[j], 0);
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1); sub_zp<a_type_id>(frag_b1, frag_zp[j], 1);
} }
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order && !is_a_8bit) {
static_assert(group_blocks != -1); static_assert(group_blocks != -1);
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<a_type_id>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], scale4<a_type_id>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
group_blocks == -1) { group_blocks == -1 && !is_a_8bit) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2( scalar_t2 s2 = Adtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx], reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x); scale_and_sub<a_type_id>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y); scale_and_sub<a_type_id>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 &&
!is_a_8bit) {
if (is_new_zp) if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j], frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j])); *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); scale_and_sub<a_type_id>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); scale_and_sub<a_type_id>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
} else if constexpr (group_blocks != -1) { } else if constexpr (group_blocks != -1 && !is_a_8bit) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0); scale<a_type_id>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1); scale<a_type_id>(frag_b1, frag_s[k2][j], 1);
} }
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); mma_trans<a_type_id>(frag_a[k2][i], frag_b0, frag_b1,
frag_c[i][j][0]);
} else { } else {
mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]); mma<a_type_id>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]); mma<a_type_id>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
}
}
}
};
auto matmul_a8 = [&](int k) {
int k2 = k % 2;
#pragma unroll
for (int j = 0; j < 2; j++) {
FragB frag_b[2];
if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) {
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b));
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2);
} else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) {
int off = (threadIdx.x / 32) % 2 * 2 + j;
int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2],
reinterpret_cast<scalar_32bit_t*>(&frag_b), zp);
zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF;
dequant_data(frag_b_quant[k2][0][j * 2 + 1],
reinterpret_cast<scalar_32bit_t*>(&frag_b) + 2, zp);
} else {
reinterpret_cast<int2*>(&frag_b)[0] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[0];
reinterpret_cast<int2*>(&frag_b)[1] =
reinterpret_cast<int2*>(&frag_b_quant[k2][j])[1];
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<a_type_id, 32>(frag_a[k2][i], frag_b[0],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]);
mma<a_type_id, 32>(frag_a[k2][i], frag_b[1],
(group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][1]);
}
if constexpr (group_blocks != -1) {
if (group_blocks == 2 || k == 1) {
if constexpr (a_type == vllm::kS8) {
int2 s_vals[2];
s_vals[0] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2][0])[1]};
s_vals[1] = {
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[0],
(int)reinterpret_cast<uint16_t*>(&frag_s[k2][j * 2 + 1][0])[1]};
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[0])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][0][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][0][g]) *
scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
int scale = reinterpret_cast<int*>(&s_vals[1])[g % 2];
*reinterpret_cast<int32_t*>(&frag_c[i][j][1][g]) +=
*reinterpret_cast<int32_t*>(&frag_c_tmp[i][j][1][g]) *
scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
} else {
float2 s_vals[2];
if constexpr (s_type_id != vllm::kFE8M0fnu.id()) {
static_assert(a_type.size_bits() == 16 ||
s_type.size_bits() == 16);
s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]);
s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]);
} else {
int32_t* s_vals_int = reinterpret_cast<int32_t*>(&s_vals[0]);
int32_t s_vals_e8m0 =
*reinterpret_cast<int32_t*>(&frag_s[k2][j][0]);
s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23;
s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15;
s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7;
s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1;
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[0])[g % 2];
frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale;
frag_c_tmp[i][j][0][g] = 0.0f;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&s_vals[1])[g % 2];
frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale;
frag_c_tmp[i][j][1][g] = 0.0f;
}
}
}
} }
} }
} }
...@@ -1263,7 +1518,8 @@ __global__ void Marlin( ...@@ -1263,7 +1518,8 @@ __global__ void Marlin(
constexpr int red_off = threads / b_sh_stride_threads / 2; constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) { if (red_off >= 1) {
auto red_idx = threadIdx.x / b_sh_stride_threads; auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_stride =
b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2;
constexpr int red_sh_delta = b_sh_stride_threads; constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads); (threadIdx.x % b_sh_stride_threads);
...@@ -1278,7 +1534,8 @@ __global__ void Marlin( ...@@ -1278,7 +1534,8 @@ __global__ void Marlin(
for (int i = red_off; i > 0; i /= 2) { for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) { if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2;
j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr = int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i); red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) { if (i < red_off) {
...@@ -1287,24 +1544,26 @@ __global__ void Marlin( ...@@ -1287,24 +1544,26 @@ __global__ void Marlin(
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]); float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll #pragma unroll
for (int k = 0; k < 4; k++) for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += reinterpret_cast<FragC*>(
frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k]; c_rd[k] + c_wr[k];
} }
sh_red[red_sh_wr] = sh_red[red_sh_wr] = reinterpret_cast<int4*>(
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j];
} }
} }
__syncthreads(); __syncthreads();
} }
if (red_idx == 0) { if (red_idx == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2;
i += (m_block_size_8 ? 2 : 1)) {
float* c_rd = float* c_rd =
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]); reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += reinterpret_cast<FragC*>(
c_rd[j]; frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] += c_rd[j];
} }
} }
__syncthreads(); __syncthreads();
...@@ -1320,10 +1579,10 @@ __global__ void Marlin( ...@@ -1320,10 +1579,10 @@ __global__ void Marlin(
// We are very careful here to reduce directly in the output buffer to // We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out // maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute). // results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4; constexpr int active_threads = 32 * tb_n_warps;
if (threadIdx.x < active_threads) { if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8; int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1);
int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr; int c_gl_wr;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
...@@ -1331,9 +1590,9 @@ __global__ void Marlin( ...@@ -1331,9 +1590,9 @@ __global__ void Marlin(
4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8;
c_gl_wr += (2 * thread_n_blocks) * slice_col; c_gl_wr += (2 * thread_n_blocks) * slice_col;
} else { } else {
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) +
4 * (threadIdx.x / 32) + threadIdx.x % 4; 4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col; c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1);
} }
constexpr int c_sh_wr_delta = active_threads; constexpr int c_sh_wr_delta = active_threads;
auto c_sh_wr = threadIdx.x; auto c_sh_wr = threadIdx.x;
...@@ -1351,6 +1610,14 @@ __global__ void Marlin( ...@@ -1351,6 +1610,14 @@ __global__ void Marlin(
&C[c_gl_wr + i * c_gl_stride + &C[c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i], (threadIdx.x % 8) / 4 * c_gl_wr_delta_i],
(threadIdx.x % 4) * 2 + i < prob_m); (threadIdx.x % 4) * 2 + i < prob_m);
} else if constexpr (is_a_8bit) {
int2* sh_red_int2 = reinterpret_cast<int2*>(sh_red);
int2* c_int2 = reinterpret_cast<int2*>(C);
cp_async2_ca_pred(
&sh_red_int2[c_sh_wr + c_sh_wr_delta * i],
&c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
} else { } else {
cp_async4_pred( cp_async4_pred(
&sh_red[c_sh_wr + c_sh_wr_delta * i], &sh_red[c_sh_wr + c_sh_wr_delta * i],
...@@ -1370,36 +1637,51 @@ __global__ void Marlin( ...@@ -1370,36 +1637,51 @@ __global__ void Marlin(
(m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m);
if (mask) { if (mask) {
if (!first) { if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; c_scalar_t* c_red_f16;
if constexpr (is_a_8bit) {
int2 tmp =
reinterpret_cast<int2*>(sh_red)[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
} else {
int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta];
c_red_f16 = reinterpret_cast<c_scalar_t*>(&tmp);
}
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0; int delta = 0;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0; delta = j % 2 == 1 ? -2 : 0;
} }
reinterpret_cast<float*>( reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j +
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]); (i % 4) + delta] += Cdtype::num2float(c_red_f16[j]);
} }
} }
if (!last) { if (!last) {
int4 c; c_scalar_t c_f16[is_a_8bit ? 4 : 8];
#pragma unroll #pragma unroll
for (int j = 0; j < 2 * 4; j++) { for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) {
int delta = 0; int delta = 0;
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0; delta = j % 2 == 1 ? -2 : 0;
} }
reinterpret_cast<scalar_t*>(&c)[j] = c_f16[j] = Cdtype::float2num(reinterpret_cast<float*>(
Dtype::float2num(reinterpret_cast<float*>( &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j +
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); (i % 4) + delta]);
} }
if constexpr (m_block_size_8) if constexpr (m_block_size_8) {
C[c_gl_wr + i * c_gl_stride + C[c_gl_wr + i * c_gl_stride +
(threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] =
else *reinterpret_cast<int4*>(c_f16);
} else if constexpr (is_a_8bit) {
int2* c_int2 = reinterpret_cast<int2*>(C);
c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)] =
*reinterpret_cast<int2*>(c_f16);
} else {
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)] = c; c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast<int4*>(c_f16);
}
} }
} }
} }
...@@ -1414,10 +1696,10 @@ __global__ void Marlin( ...@@ -1414,10 +1696,10 @@ __global__ void Marlin(
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4; constexpr int active_threads = 32 * tb_n_warps;
bool is_th_active = threadIdx.x < active_threads; bool is_th_active = threadIdx.x < active_threads;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16; constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = locks_off * c_size; int c_cur_offset = locks_off * c_size;
...@@ -1471,7 +1753,7 @@ __global__ void Marlin( ...@@ -1471,7 +1753,7 @@ __global__ void Marlin(
} else { } else {
c_sh_wr = c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32); c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32);
} }
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
...@@ -1481,47 +1763,47 @@ __global__ void Marlin( ...@@ -1481,47 +1763,47 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final // We first reorder in shared memory to guarantee the most efficient final
// global write patterns // global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) { auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
scalar_t2 res = c_scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for // For per-column quantization we finally apply the scale here (only for
// 4-bit) // 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit &&
w_type.size_bits() == 4 && b_type.size_bits() == 4 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
scalar_t2 tmp_scale = s[0]; c_scalar_t2 tmp_scale = s[0];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
tmp_scale = Dtype::num2num2( tmp_scale = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]); reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
} }
res = __hmul2(res, tmp_scale); res = __hmul2(res, tmp_scale);
} }
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
res = __hmul2(res, global_scale); res = __hmul2(res, global_scale);
} }
if (has_bias && last) { if (has_bias && last) {
scalar_t2 tmp_bias = b_bias[0]; c_scalar_t2 tmp_bias = b_bias[0];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
tmp_bias = Dtype::num2num2( tmp_bias = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]); reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
} }
res = __hadd2(res, tmp_bias); res = __hadd2(res, tmp_bias);
} }
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x; ((c_scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; ((c_scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
} else { } else {
((scalar_t2*)sh_red)[idx] = res; ((c_scalar_t2*)sh_red)[idx] = res;
} }
}; };
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < (is_a_8bit ? 2 : 4); j++) {
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j; int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
...@@ -1557,9 +1839,9 @@ __global__ void Marlin( ...@@ -1557,9 +1839,9 @@ __global__ void Marlin(
i++) { i++) {
if (c_gl_wr < c_gl_wr_end) { if (c_gl_wr < c_gl_wr_end) {
if (use_atomic_add && slice_count > 1) { if (use_atomic_add && slice_count > 1) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]); c_scalar_t2* C_half2 = reinterpret_cast<c_scalar_t2*>(&C[c_gl_wr]);
scalar_t2* sh_red_half2 = c_scalar_t2* sh_red_half2 =
reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]); reinterpret_cast<c_scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll #pragma unroll
for (int a = 0; a < 4; a++) { for (int a = 0; a < 4; a++) {
atomicAdd(&C_half2[a], sh_red_half2[a]); atomicAdd(&C_half2[a], sh_red_half2[a]);
...@@ -1635,7 +1917,13 @@ __global__ void Marlin( ...@@ -1635,7 +1917,13 @@ __global__ void Marlin(
wait_for_stage(); wait_for_stage();
init_same_group(pipe % stages); init_same_group(pipe % stages);
} }
matmul(k);
if constexpr (!is_a_8bit) {
matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 1 : 0));
} else {
static_assert(group_blocks != 0 && group_blocks != 1);
matmul_a8(k);
}
} }
slice_iters--; slice_iters--;
if (slice_iters == 0) { if (slice_iters == 0) {
...@@ -1668,13 +1956,47 @@ __global__ void Marlin( ...@@ -1668,13 +1956,47 @@ __global__ void Marlin(
// While this pattern may not be the most readable, other ways of writing // While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation. // the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) { if (slice_iters == 0) {
if constexpr (is_a_8bit) {
float frag_a_s[2 * thread_m_blocks];
for (int i = 0; i < 2 * thread_m_blocks; i++)
frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4];
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][0][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][0][g] = c_val * s_val;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float c_val = frag_c[i][j][1][g];
if constexpr (a_type == vllm::kS8) {
c_val = __int2float_rn(*reinterpret_cast<int32_t*>(&c_val));
}
float s_val = frag_a_s[i * 2 + g / 2];
frag_c[i][j][1][g] = c_val * s_val;
}
}
}
}
cp_async_wait<0>(); cp_async_wait<0>();
bool last = slice_idx == slice_count - 1; bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
} }
...@@ -1692,20 +2014,27 @@ __global__ void Marlin( ...@@ -1692,20 +2014,27 @@ __global__ void Marlin(
} }
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 &&
(has_zp && dequant_skip_flop || !has_zp)) { (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if constexpr (is_a_8bit) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
}
} else if (b_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < tb_n_warps) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
if constexpr (m_block_size_8) { if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2; int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s); c_scalar_t2* frag_s_half2 =
reinterpret_cast<c_scalar_t2*>(frag_s);
#pragma unroll #pragma unroll
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2( frag_s_half2[i] = Cdtype::num2num2(
reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]); reinterpret_cast<c_scalar_t*>(&frag_s_half2[i])[idx]);
} }
} }
} }
...@@ -1715,26 +2044,48 @@ __global__ void Marlin( ...@@ -1715,26 +2044,48 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction // For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible // that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16) // overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) {
w_type.size_bits() == 8 && #pragma unroll
(has_zp && dequant_skip_flop || !has_zp)) { for (int j = 0; j < 2; j++) {
if (threadIdx.x / 32 < thread_n_blocks / 4) { float2 aa[2];
aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]);
aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[0])[g % 2];
frag_c[i][j][0][g] *= scale;
}
#pragma unroll
for (int g = 0; g < 4; g++) {
float scale = reinterpret_cast<float*>(&aa[1])[g % 2];
frag_c[i][j][1][g] *= scale;
}
}
}
} else if (!has_act_order && group_blocks == -1 &&
b_type.size_bits() == 8 &&
(has_zp && dequant_skip_flop || !has_zp)) {
if (threadIdx.x / 32 < tb_n_warps) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][0]), reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]); frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]), reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
if constexpr (!m_block_size_8) { if constexpr (!m_block_size_8) {
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][0]), reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>( scale_float<c_type_id>(
reinterpret_cast<float*>(&frag_c[i][j][1][2]), reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]); frag_s[j / 2][2 * (j % 2) + 1]);
} }
...@@ -1758,7 +2109,8 @@ __global__ void Marlin( ...@@ -1758,7 +2109,8 @@ __global__ void Marlin(
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd]; reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; if constexpr (!is_a_8bit)
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
__syncthreads(); __syncthreads();
} }
...@@ -1768,21 +2120,22 @@ __global__ void Marlin( ...@@ -1768,21 +2120,22 @@ __global__ void Marlin(
// only the last block in a slice actually writes the result // only the last block in a slice actually writes the result
write_result(last); write_result(last);
slice_row = 0; slice_row = 0;
slice_col_par++; if (!in_part2) {
slice_col++; slice_col_par += gridDim.x;
} else {
slice_col_par++;
slice_col++;
}
is_first_matmul_in_slice = true; is_first_matmul_in_slice = true;
init_slice(); init_slice();
if (slice_iters) { if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o); (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll a_gl_rd += a_gl_rd_delta_o * slice_row;
for (int i = 0; i < b_sh_wr_iters; i++) b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) +
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; (threadIdx.x % b_sh_stride);
if (slice_col == 0) { b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
// Update slice k/n for scales loading // Update slice k/n for scales loading
...@@ -1791,12 +2144,28 @@ __global__ void Marlin( ...@@ -1791,12 +2144,28 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start; slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col; slice_n_offset = act_s_col_tb_stride * slice_col;
} else { } else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else if constexpr (group_blocks >= thread_k_blocks) {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd =
s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / s_sh_stride) +
s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
zp_gl_rd =
zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks +
threadIdx.x / zp_sh_stride) +
zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
}
} }
start_pipes(); start_pipes();
} }
} }
......
...@@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none," "Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file // conditionally compiled so impl registration is in source file
...@@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin repack from GPTQ. // gptq_marlin repack from GPTQ.
ops.def( ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ. // awq_marlin repack from AWQ.
ops.def( ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"); "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// preprocess W-int4A-fp8 weight for marlin kernel
ops.def(
"marlin_int4_fp8_preprocess(Tensor qweight, "
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM // CUTLASS w4a8 GEMM
......
...@@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes. ...@@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod] - [`ModelOptFp8MoEMethod`][vllm.model_executor.layers.quantization.modelopt.ModelOptFp8MoEMethod]
- [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod] - [`Fp8MoEMethod`][vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod]
- [`CompressedTensorsW4A4Nvfp4MoeMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoeMethod] - [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod] - [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Fp8MoEMethod]
- [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod] - [`Mxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.Mxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod] - [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
......
...@@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -21,7 +21,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import init_distributed_environment
...@@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192] ...@@ -65,6 +65,64 @@ NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4] EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
MOE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
FUSED_MOE_MNK_FACTORS = [ FUSED_MOE_MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 2048, 128), (1, 2048, 128),
...@@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases(): ...@@ -505,63 +563,74 @@ def marlin_moe_generate_valid_test_cases():
m_list = [1, 123, 666] m_list = [1, 123, 666]
n_list = [128, 1024] n_list = [128, 1024]
k_list = [256, 2048] k_list = [256, 2048]
e_list = [4, 12] e_list = [5, 12]
topk_list = [2, 3] topk_list = [2, 3]
ep_size_list = [1, 4] ep_size_list = [1, 4]
dtype_list = [torch.bfloat16]
group_size_list = [-1, 32, 128]
act_order_list = [True, False] act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
scalar_types.float8_e4m3fn,
scalar_types.uint4,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
is_k_full_list = [True, False] is_k_full_list = [True, False]
all_combinations = itertools.product( all_combinations = itertools.product(
MOE_MARLIN_QUANT_TEST_CONFIGS,
m_list, m_list,
n_list, n_list,
k_list, k_list,
e_list, e_list,
topk_list, topk_list,
ep_size_list, ep_size_list,
dtype_list,
group_size_list,
act_order_list, act_order_list,
quant_type_list,
is_k_full_list, is_k_full_list,
) )
def is_invalid( def is_invalid(
m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full a_type,
b_type,
c_type,
group_blocks,
m,
n,
k,
e,
topk,
ep_size,
act_order,
is_k_full,
): ):
if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
return False if group_size > 0 and k % group_size != 0:
if quant_type == scalar_types.float4_e2m1f:
if group_size not in [16, 32]:
return False
if dtype == torch.float16 and group_size == 32:
return False
if quant_type != scalar_types.float4_e2m1f and group_size == 16:
return False return False
# Filter act_order if act_order and group_size in [-1, k, n]:
if act_order: return False
if group_size in (-1, k, n): if group_size in [k, n]:
return False return False
if quant_type not in [scalar_types.uint4b8]: if not act_order and is_k_full:
return False
elif not is_k_full:
return False return False
return True return a_type.size_bits < 16 or a_type is c_type
cases = [] cases = []
for case in all_combinations: for case in all_combinations:
if is_invalid(*case): quant_test_config, m, n, k, _, _, _, act_order, *_ = case
cases.append(case) if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (m, n, k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases return cases
...@@ -571,6 +640,7 @@ class MarlinMoEWeightData: ...@@ -571,6 +640,7 @@ class MarlinMoEWeightData:
qweight: torch.Tensor qweight: torch.Tensor
scales: torch.Tensor scales: torch.Tensor
global_scale: torch.Tensor | None global_scale: torch.Tensor | None
a_scales_factor: torch.Tensor | None
g_idx: torch.Tensor | None g_idx: torch.Tensor | None
zeros: torch.Tensor | None zeros: torch.Tensor | None
sort_indices: torch.Tensor | None sort_indices: torch.Tensor | None
...@@ -583,11 +653,20 @@ class MarlinMoEWeightData: ...@@ -583,11 +653,20 @@ class MarlinMoEWeightData:
group_size: int, group_size: int,
act_order: bool | None = None, act_order: bool | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
input_type: ScalarType = None,
) -> "MarlinMoEWeightData": ) -> "MarlinMoEWeightData":
assert w.ndim == 3 assert w.ndim == 3
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
k = w.shape[-1] k = w.shape[-1]
if input_type == scalar_types.int8:
input_dtype = torch.int8
elif input_type == scalar_types.float8_e4m3fn:
input_dtype = torch.float8_e4m3fn
else:
input_dtype = w.dtype
w_ref_l: list[torch.Tensor] = [] w_ref_l: list[torch.Tensor] = []
qweight_l: list[torch.Tensor] = [] qweight_l: list[torch.Tensor] = []
scales_l: list[torch.Tensor] = [] scales_l: list[torch.Tensor] = []
...@@ -601,11 +680,13 @@ class MarlinMoEWeightData: ...@@ -601,11 +680,13 @@ class MarlinMoEWeightData:
if quant_type == scalar_types.float4_e2m1f: if quant_type == scalar_types.float4_e2m1f:
if group_size == 16: if group_size == 16:
w_ref, qweight, scales, global_scale = ( w_ref, qweight, scales, global_scale = (
rand_marlin_weight_nvfp4_like(w[i], group_size) rand_marlin_weight_nvfp4_like(
w[i], group_size, input_dtype=input_dtype
)
) )
else: else:
w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
w[i], group_size w[i], group_size, input_dtype=input_dtype
) )
global_scale = None global_scale = None
...@@ -615,13 +696,18 @@ class MarlinMoEWeightData: ...@@ -615,13 +696,18 @@ class MarlinMoEWeightData:
if global_scale is not None: if global_scale is not None:
global_scale_l.append(global_scale) global_scale_l.append(global_scale)
elif quant_type == scalar_types.float8_e4m3fn: elif quant_type == scalar_types.float8_e4m3fn:
w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) w_ref, qweight, scales = marlin_quant_fp8_torch(
w[i], group_size, input_dtype=input_dtype
)
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
qweight_l.append(qweight) qweight_l.append(qweight)
scales_l.append(scales) scales_l.append(scales)
elif has_zp: elif has_zp:
w_ref, qweight, scales, zeros = awq_marlin_quantize( w_ref, qweight, scales, zeros = awq_marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size w[i].transpose(1, 0),
quant_type,
group_size,
input_dtype=input_dtype,
) )
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
...@@ -631,7 +717,12 @@ class MarlinMoEWeightData: ...@@ -631,7 +717,12 @@ class MarlinMoEWeightData:
else: else:
test_perm = torch.randperm(k) test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm w[i].transpose(1, 0),
quant_type,
group_size,
act_order,
test_perm,
input_dtype=input_dtype,
) )
w_ref_l.append(w_ref.T) w_ref_l.append(w_ref.T)
...@@ -652,11 +743,18 @@ class MarlinMoEWeightData: ...@@ -652,11 +743,18 @@ class MarlinMoEWeightData:
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
marlin_bias = stack_and_dev(bias_l) if bias_l else None marlin_bias = stack_and_dev(bias_l) if bias_l else None
a_scales_factor = None
if input_type == scalar_types.int8 and group_size != -1:
a_scales_factor = 1 / 4096 * scales.max().float()
scales = scales / scales.max() * 4096
scales = scales.round().to(torch.int16).view(w.dtype)
return MarlinMoEWeightData( return MarlinMoEWeightData(
w_ref=w_ref, w_ref=w_ref,
qweight=qweight, qweight=qweight,
scales=scales, scales=scales,
global_scale=global_scale, global_scale=global_scale,
a_scales_factor=a_scales_factor,
g_idx=g_idx, g_idx=g_idx,
zeros=zeros, zeros=zeros,
sort_indices=sort_indices, sort_indices=sort_indices,
...@@ -666,28 +764,47 @@ class MarlinMoEWeightData: ...@@ -666,28 +764,47 @@ class MarlinMoEWeightData:
@pytest.mark.flaky(reruns=2) @pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), (
"a_type, b_type, c_type, group_blocks,"
"m, n, k, e, topk, ep_size, act_order, is_k_full"
),
marlin_moe_generate_valid_test_cases(), marlin_moe_generate_valid_test_cases(),
) )
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, a_type,
n: int, b_type,
k: int, c_type,
e: int, group_blocks,
topk: int, m,
ep_size: int, n,
dtype: torch.dtype, k,
group_size: int, e,
act_order: bool, topk,
quant_type: ScalarType, ep_size,
is_k_full: bool, act_order,
is_k_full,
): ):
torch.cuda.manual_seed(0) torch.cuda.manual_seed(1)
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if c_type == scalar_types.float16:
dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
if a_type == scalar_types.int8:
a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
if ep_size > 1: if ep_size > 1:
local_e = e // ep_size local_e = e // ep_size
...@@ -700,11 +817,19 @@ def test_fused_marlin_moe( ...@@ -700,11 +817,19 @@ def test_fused_marlin_moe(
e_map = None e_map = None
w1_data = MarlinMoEWeightData.make( w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order w=w1,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
) )
w2_data = MarlinMoEWeightData.make( w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order w=w2,
quant_type=b_type,
group_size=group_size,
act_order=act_order,
input_type=a_type,
) )
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
...@@ -712,8 +837,18 @@ def test_fused_marlin_moe( ...@@ -712,8 +837,18 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
torch_output = torch_moe( score = torch.softmax(score, dim=-1, dtype=torch.float32)
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map topk_weight, topk_ids = torch.topk(score, topk)
torch_output = torch_experts(
a,
w1_data.w_ref,
w2_data.w_ref,
topk_weight=topk_weight,
topk_ids=topk_ids,
global_num_experts=e,
expert_map=e_map,
quant_dtype=a_dtype,
per_act_token_quant=True,
) )
marlin_output = fused_marlin_moe( marlin_output = fused_marlin_moe(
...@@ -733,15 +868,18 @@ def test_fused_marlin_moe( ...@@ -733,15 +868,18 @@ def test_fused_marlin_moe(
global_scale2=w2_data.global_scale, global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx, g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx, g_idx2=w2_data.g_idx,
input_global_scale1=w1_data.a_scales_factor,
input_global_scale2=w2_data.a_scales_factor,
sort_indices1=w1_data.sort_indices, sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices, sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros, w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros, w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id, input_dtype=a_dtype,
quant_type_id=b_type.id,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
@pytest.mark.flaky(reruns=2) @pytest.mark.flaky(reruns=2)
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
""" """
import itertools
import pytest import pytest
import torch import torch
...@@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( ...@@ -17,8 +19,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
) )
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
...@@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -26,7 +30,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
query_marlin_supported_quant_types, query_marlin_supported_quant_types,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like, rand_marlin_weight_nvfp4_like,
) )
...@@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -50,6 +53,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights, quantize_weights,
sort_weights, sort_weights,
) )
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True] ACT_ORDER_OPTS = [False, True]
...@@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512] ...@@ -65,6 +69,12 @@ MARLIN_24_N_CHUNKS = [512]
HQQ_SUPPORTED_GROUP_SIZES = [64] HQQ_SUPPORTED_GROUP_SIZES = [64]
MARLIN_REPACK_NK_FACTORS = [
(4, 8),
(7, 5),
(13, 11),
]
MNK_FACTORS = [ MNK_FACTORS = [
(1, 1, 1), (1, 1, 1),
(1, 4, 8), (1, 4, 8),
...@@ -74,6 +84,64 @@ MNK_FACTORS = [ ...@@ -74,6 +84,64 @@ MNK_FACTORS = [
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
# AWQ-INT4
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
# GPTQ-INT4
{
"b_type": scalar_types.uint4b8,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": scalar_types.uint8b128,
"support_act_order": True,
"group_blocks": [-1, 2, 4, 8],
},
# FP8
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
# NVFP4
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
# MXFP4
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4b8,
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.uint4,
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": [scalar_types.float8_e4m3fn],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
]
def compute_max_diff(output, output_ref): def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean( return torch.mean(torch.abs(output - output_ref)) / torch.mean(
...@@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16): ...@@ -85,6 +153,58 @@ def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda") return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)
torch_res = torch.where(
qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
)
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
group_size = 128
qweight_unpacked = torch.randint(
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
)
qzeros_unpacked = torch.randint(
0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
)
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)
repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
torch_res = qweight_unpacked - repeated_zp
torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
torch_res = torch_res.to(torch.int8).view(torch.int32)
assert (cuda_res == torch_res).all()
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"), not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.", reason="Marlin is not supported on this GPU type.",
...@@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16): ...@@ -92,16 +212,17 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack( def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
): ):
m_factor, n_factor, k_factor = mnk_factors n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
group_size = 128
# Filter act_order # Filter act_order
if act_order: if act_order:
...@@ -109,6 +230,8 @@ def test_gptq_marlin_repack( ...@@ -109,6 +230,8 @@ def test_gptq_marlin_repack(
return return
if group_size == size_k: if group_size == size_k:
return return
if is_a_8bit:
return
# Normalize group_size # Normalize group_size
if group_size == -1: if group_size == -1:
...@@ -133,23 +256,19 @@ def test_gptq_marlin_repack( ...@@ -133,23 +256,19 @@ def test_gptq_marlin_repack(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights( marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
) )
opcheck( opcheck(
torch.ops._C.gptq_marlin_repack, torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
) )
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack( marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq, q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
sort_indices,
size_k,
size_n,
quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -163,18 +282,15 @@ def test_gptq_marlin_repack( ...@@ -163,18 +282,15 @@ def test_gptq_marlin_repack(
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
m_factor, n_factor, k_factor = mnk_factors n_factor, k_factor = nk_factors
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
# Normalize group_size group_size = 128
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Create input # Create input
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
...@@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors ...@@ -188,162 +304,221 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format # Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w_1 = marlin_weights( marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
) )
opcheck( opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
) )
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack( marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq, q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
size_k,
size_n,
quant_type.size_bits,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
def marlin_generate_valid_test_cases():
all_combinations = itertools.product(
DENSE_MARLIN_QUANT_TEST_CONFIGS,
MNK_FACTORS,
MARLIN_N_CHUNKS,
MARLIN_K_CHUNKS,
ACT_ORDER_OPTS,
K_FULL_OPTS,
USE_ATOMIC_ADD_OPTS,
USE_FP32_REDUCE_OPTS,
)
def is_invalid(
a_type,
b_type,
c_type,
group_blocks,
size_m,
size_n,
size_k,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
):
if use_atomic_add:
if use_fp32_reduce:
return False
if (
c_type == scalar_types.bfloat16
and torch.cuda.get_device_capability()[0] < 9
):
return False
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and size_k % group_size != 0:
return False
if act_order and group_size in [-1, size_k]:
return False
if group_size == size_k:
return False
if not act_order and is_k_full:
return False
return a_type.size_bits < 16 or a_type is c_type
cases = []
for case in all_combinations:
quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
size_m = mnk_factors[0]
size_n = mnk_factors[1] * n_chunk
size_k = mnk_factors[2] * k_chunk
if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16, scalar_types.bfloat16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (size_m, size_n, size_k) + case[4:]
if is_invalid(*args):
cases.append(args)
return cases
@pytest.mark.skipif( @pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"), not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.", reason="Marlin is not supported on this GPU type.",
) )
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) (
"a_type, b_type, c_type, group_blocks,"
"size_m, size_n, size_k, act_order, is_k_full,"
"use_atomic_add, use_fp32_reduce"
),
marlin_generate_valid_test_cases(),
) )
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm( def test_gptq_marlin_gemm(
k_chunk, a_type,
n_chunk, b_type,
quant_type, c_type,
group_size, group_blocks,
mnk_factors, size_m,
size_n,
size_k,
act_order, act_order,
is_k_full, is_k_full,
use_atomic_add, use_atomic_add,
use_fp32_reduce, use_fp32_reduce,
dtype,
): ):
m_factor, n_factor, k_factor = mnk_factors has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
size_m = m_factor group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
if act_order:
if group_size == -1:
return
if group_size == size_k:
return
if has_zp:
return
if size_k % group_size != 0: if c_type == scalar_types.float16:
return dtype = torch.float16
elif c_type == scalar_types.bfloat16:
dtype = torch.bfloat16
else:
raise RuntimeError("unsupported c_type")
a_input = rand_data((size_m, size_k), dtype) if a_type == scalar_types.int8:
b_weight = rand_data((size_k, size_n), dtype) a_dtype = torch.int8
elif a_type == scalar_types.float8_e4m3fn:
a_dtype = torch.float8_e4m3fn
else:
a_dtype = dtype
if quant_type == scalar_types.float4_e2m1f: a_input = rand_data((size_m, size_k), dtype=dtype)
if group_size not in [16, 32] or act_order: b_weight = rand_data((size_k, size_n), dtype=dtype)
return
if group_size == 32 and dtype == torch.float16:
return
if b_type == scalar_types.float4_e2m1f:
if group_size == 16: if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like( w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size b_weight.T, group_size, input_dtype=a_dtype
) )
else: else:
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like( w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size b_weight.T, group_size, input_dtype=a_dtype
) )
marlin_s2 = None marlin_s2 = None
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_zp = None marlin_zp = None
elif quant_type == scalar_types.float8_e4m3fn: elif b_type == scalar_types.float8_e4m3fn:
if group_size not in [-1, 128]: w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
return b_weight.T, group_size, input_dtype=a_dtype
if act_order: )
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_zp = None marlin_zp = None
marlin_s2 = None marlin_s2 = None
elif has_zp: elif has_zp:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size b_weight, b_type, group_size, input_dtype=a_dtype
) )
g_idx = None g_idx = None
sort_indices = None sort_indices = None
marlin_s2 = None marlin_s2 = None
else: else:
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order b_weight, b_type, group_size, act_order, input_dtype=a_dtype
) )
marlin_zp = None marlin_zp = None
marlin_s2 = None marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device) workspace = marlin_make_workspace_new(w_ref.device)
opcheck( if a_type == scalar_types.int8:
torch.ops._C.gptq_marlin_gemm, a_input, a_scales = per_token_quant_int8(a_input)
( a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
a_input, a_input_ref = a_input_ref.to(dtype)
None,
marlin_q_w, if group_size != -1:
None, a_scales = a_scales / 4096 * marlin_s.max()
marlin_s, a_scales = a_scales.float()
marlin_s2, marlin_s = marlin_s / marlin_s.max() * 4096
marlin_zp, marlin_s = marlin_s.round().to(torch.int16).view(dtype)
g_idx, elif a_type == scalar_types.float8_e4m3fn:
sort_indices, a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
workspace, a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
quant_type.id, a_input_ref = a_input_ref.to(dtype)
a_input.shape[0], else:
b_weight.shape[1], assert a_type.size_bits == 16
a_input.shape[1], a_input_ref = a_input
is_k_full, a_scales = None
use_atomic_add,
use_fp32_reduce, output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
None, output,
marlin_q_w, marlin_q_w,
None, None,
marlin_s, marlin_s,
a_scales,
marlin_s2, marlin_s2,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
workspace, workspace,
quant_type, b_type,
a_input.shape[0], a_input.shape[0],
b_weight.shape[1], b_weight.shape[1],
a_input.shape[1], a_input.shape[1],
...@@ -352,12 +527,9 @@ def test_gptq_marlin_gemm( ...@@ -352,12 +527,9 @@ def test_gptq_marlin_gemm(
use_fp32_reduce=use_fp32_reduce, use_fp32_reduce=use_fp32_reduce,
is_zp_float=False, is_zp_float=False,
) )
output_ref = torch.matmul(a_input, w_ref) output_ref = torch.matmul(a_input_ref, w_ref)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04 assert max_diff < 0.04
...@@ -507,6 +679,7 @@ def test_hqq_marlin_gemm( ...@@ -507,6 +679,7 @@ def test_hqq_marlin_gemm(
None, None,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
...@@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input(): ...@@ -559,6 +732,7 @@ def test_marlin_gemm_subset_input():
None, None,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
...@@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m): ...@@ -607,6 +781,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_bias, marlin_bias,
marlin_s, marlin_s,
None, None,
None,
marlin_zp, marlin_zp,
g_idx, g_idx,
sort_indices, sort_indices,
......
...@@ -846,6 +846,13 @@ def torch_experts( ...@@ -846,6 +846,13 @@ def torch_experts(
or (expert_map is not None and global_num_experts == expert_map.shape[0]) or (expert_map is not None and global_num_experts == expert_map.shape[0])
) )
if quant_dtype in [torch.float16, torch.bfloat16]:
quant_dtype = None
quant_input_only = quant_dtype is not None and w1_scale is None and w2_scale is None
if quant_input_only:
assert a1_scale is None and a2_scale is None
assert per_act_token_quant
M, K = a.shape M, K = a.shape
topk = topk_ids.shape[1] topk = topk_ids.shape[1]
...@@ -863,6 +870,9 @@ def torch_experts( ...@@ -863,6 +870,9 @@ def torch_experts(
a, a1_scale, quant_dtype, per_act_token_quant, block_shape a, a1_scale, quant_dtype, per_act_token_quant, block_shape
) )
if quant_input_only:
a = (a.float() * a_scale.view(-1, 1)).to(w1.dtype)
num_experts = w1.shape[0] num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1) topk_ids = topk_ids.view(-1)
...@@ -882,6 +892,14 @@ def torch_experts( ...@@ -882,6 +892,14 @@ def torch_experts(
out[mask] = tmp2 @ w2[i].transpose(0, 1) out[mask] = tmp2 @ w2[i].transpose(0, 1)
if b_bias2 is not None: if b_bias2 is not None:
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
elif quant_input_only:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
tmp2, tmp2_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant
)
tmp2 = (tmp2.float() * tmp2_scale.view(-1, 1)).to(w2.dtype)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None: elif block_shape is not None:
# block quantized # block quantized
assert ( assert (
......
...@@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -554,6 +554,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None, b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
...@@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -568,7 +569,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce: bool = False, use_fp32_reduce: bool = False,
is_zp_float: bool = False, is_zp_float: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) dtype = a.dtype
if dtype not in [torch.half, torch.bfloat16]:
dtype = b_scales.dtype
return torch.empty((size_m, size_n), device=a.device, dtype=dtype)
@register_fake("_C::awq_dequantize") @register_fake("_C::awq_dequantize")
def _awq_dequantize_fake( def _awq_dequantize_fake(
...@@ -1167,8 +1171,11 @@ def gptq_marlin_repack( ...@@ -1167,8 +1171,11 @@ def gptq_marlin_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) return torch.ops._C.gptq_marlin_repack(
b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit
)
if hasattr(torch.ops._C, "gptq_marlin_repack"): if hasattr(torch.ops._C, "gptq_marlin_repack"):
...@@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"): ...@@ -1180,6 +1187,7 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
size_k: torch.SymInt, size_k: torch.SymInt,
size_n: torch.SymInt, size_n: torch.SymInt,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
pack_factor = 32 // num_bits pack_factor = 32 // num_bits
marlin_tile_size = 16 marlin_tile_size = 16
...@@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"): ...@@ -1192,9 +1200,15 @@ if hasattr(torch.ops._C, "gptq_marlin_repack"):
# awq_marlin # awq_marlin
def awq_marlin_repack( def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int b_q_weight: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) return torch.ops._C.awq_marlin_repack(
b_q_weight, size_k, size_n, num_bits, is_a_8bit
)
if hasattr(torch.ops._C, "awq_marlin_repack"): if hasattr(torch.ops._C, "awq_marlin_repack"):
...@@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"): ...@@ -1205,6 +1219,7 @@ if hasattr(torch.ops._C, "awq_marlin_repack"):
size_k: torch.SymInt, size_k: torch.SymInt,
size_n: torch.SymInt, size_n: torch.SymInt,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
pack_factor = 32 // num_bits pack_factor = 32 // num_bits
marlin_tile_size = 16 marlin_tile_size = 16
...@@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack( ...@@ -1221,6 +1236,7 @@ def gptq_marlin_moe_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_experts = b_q_weight.shape[0] num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0 assert size_k % 16 == 0
...@@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack( ...@@ -1231,7 +1247,7 @@ def gptq_marlin_moe_repack(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops._C.gptq_marlin_repack( output[e] = torch.ops._C.gptq_marlin_repack(
b_q_weight[e], perm[e], size_k, size_n, num_bits b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit
) )
return output return output
...@@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack( ...@@ -1242,6 +1258,7 @@ def awq_marlin_moe_repack(
size_k: int, size_k: int,
size_n: int, size_n: int,
num_bits: int, num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_experts = b_q_weight.shape[0] num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0 assert size_k % 16 == 0
...@@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack( ...@@ -1252,17 +1269,26 @@ def awq_marlin_moe_repack(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops._C.awq_marlin_repack( output[e] = torch.ops._C.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits b_q_weight[e], size_k, size_n, num_bits, is_a_8bit
) )
return output return output
def marlin_int4_fp8_preprocess(
qweight: torch.Tensor,
qzeros_or_none: torch.Tensor | None = None,
inplace: bool = False,
):
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
def gptq_marlin_gemm( def gptq_marlin_gemm(
a: torch.Tensor, a: torch.Tensor,
c: torch.Tensor | None, c: torch.Tensor | None,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_zeros: torch.Tensor | None, b_zeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
...@@ -1283,6 +1309,7 @@ def gptq_marlin_gemm( ...@@ -1283,6 +1309,7 @@ def gptq_marlin_gemm(
b_q_weight, b_q_weight,
b_bias, b_bias,
b_scales, b_scales,
a_scales,
global_scale, global_scale,
b_zeros, b_zeros,
g_idx, g_idx,
...@@ -1600,7 +1627,7 @@ def allspark_repack_weight( ...@@ -1600,7 +1627,7 @@ def allspark_repack_weight(
if use asymmetric quantization, has_zp = True. if use asymmetric quantization, has_zp = True.
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
rearranged weight, scale, and optionally zero_point. rearranged weight, scale, and optionally zero_point.
""" """
K = qweight.shape[0] K = qweight.shape[0]
...@@ -1683,7 +1710,7 @@ def scaled_int8_quant( ...@@ -1683,7 +1710,7 @@ def scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored). symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
""" """
output = torch.empty_like(input, dtype=torch.int8) output = torch.empty_like(input, dtype=torch.int8)
if scale is not None: if scale is not None:
...@@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm( ...@@ -2004,6 +2031,7 @@ def moe_wna16_marlin_gemm(
b_qweight: torch.Tensor, b_qweight: torch.Tensor,
b_bias: torch.Tensor | None, b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None, global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None, b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
...@@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm( ...@@ -2025,6 +2053,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add: bool, use_atomic_add: bool,
use_fp32_reduce: bool, use_fp32_reduce: bool,
is_zp_float: bool, is_zp_float: bool,
thread_k: int = -1,
thread_n: int = -1,
blocks_per_sm: int = -1,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._moe_C.moe_wna16_marlin_gemm( return torch.ops._moe_C.moe_wna16_marlin_gemm(
input, input,
...@@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm( ...@@ -2032,6 +2063,7 @@ def moe_wna16_marlin_gemm(
b_qweight, b_qweight,
b_bias, b_bias,
b_scales, b_scales,
a_scales,
global_scale, global_scale,
b_qzeros, b_qzeros,
g_idx, g_idx,
...@@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm( ...@@ -2053,6 +2085,9 @@ def moe_wna16_marlin_gemm(
use_atomic_add, use_atomic_add,
use_fp32_reduce, use_fp32_reduce,
is_zp_float, is_zp_float,
thread_k,
thread_n,
blocks_per_sm,
) )
...@@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe") ...@@ -2088,7 +2123,10 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
input: torch.Tensor, input: torch.Tensor,
output: torch.Tensor | None, output: torch.Tensor | None,
b_qweight: torch.Tensor, b_qweight: torch.Tensor,
b_bias: torch.Tensor | None,
b_scales: torch.Tensor, b_scales: torch.Tensor,
a_scales: torch.Tensor | None,
global_scale: torch.Tensor | None,
b_qzeros: torch.Tensor | None, b_qzeros: torch.Tensor | None,
g_idx: torch.Tensor | None, g_idx: torch.Tensor | None,
perm: torch.Tensor | None, perm: torch.Tensor | None,
...@@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe") ...@@ -2109,7 +2147,7 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe")
use_atomic_add: bool, use_atomic_add: bool,
use_fp32_reduce: bool, use_fp32_reduce: bool,
is_zp_float: bool, is_zp_float: bool,
) -> torch.Tensor: ):
return torch.empty( return torch.empty(
(size_m * top_k, size_n), dtype=input.dtype, device=input.device (size_m * top_k, size_n), dtype=input.dtype, device=input.device
) )
...@@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant( ...@@ -2583,7 +2621,7 @@ def onednn_scaled_int8_quant(
symmetric: Whether to use symmetric quantization (scale only, azp ignored). symmetric: Whether to use symmetric quantization (scale only, azp ignored).
Returns: Returns:
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp.
""" """
output = torch.empty_like(input, dtype=torch.int8) output = torch.empty_like(input, dtype=torch.int8)
token_num = input.numel() // input.shape[-1] token_num = input.numel() // input.shape[-1]
......
...@@ -145,6 +145,7 @@ if TYPE_CHECKING: ...@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MARLIN_INPUT_DTYPE: Literal["int8", "fp8"] | None = None
VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
...@@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1122,6 +1123,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
os.environ.get("VLLM_MXFP4_USE_MARLIN", None) os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
), ),
# The activation dtype for marlin kernel
"VLLM_MARLIN_INPUT_DTYPE": env_with_choices(
"VLLM_MARLIN_INPUT_DTYPE", None, ["int8", "fp8"]
),
# Whether to turn on the outlines cache for V1 # Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in # This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users. # an environment with potentially malicious users.
......
...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in ...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_in
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_intermediate_size, marlin_moe_intermediate_size,
maybe_warn_marlin_atomic_add, marlin_quant_input,
) )
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
...@@ -65,6 +65,8 @@ def _fused_marlin_moe( ...@@ -65,6 +65,8 @@ def _fused_marlin_moe(
activation_func: Callable[ activation_func: Callable[
[str, torch.Tensor, torch.Tensor], None [str, torch.Tensor, torch.Tensor], None
] = default_activation_func, ] = default_activation_func,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None,
...@@ -77,6 +79,7 @@ def _fused_marlin_moe( ...@@ -77,6 +79,7 @@ def _fused_marlin_moe(
intermediate_cache13: torch.Tensor | None = None, intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None, intermediate_cache2: torch.Tensor | None = None,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
assert hidden_states.ndim == 2 assert hidden_states.ndim == 2
...@@ -106,18 +109,22 @@ def _fused_marlin_moe( ...@@ -106,18 +109,22 @@ def _fused_marlin_moe(
intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N)) intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) a_scales1 = None
use_atomic_add = ( gate_up_input = hidden_states
hidden_states.dtype == torch.half if input_dtype == torch.int8:
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
) if input_global_scale1 is not None:
a_scales1 = a_scales1 * input_global_scale1
elif input_dtype == torch.float8_e4m3fn:
gate_up_input, a_scales1 = marlin_quant_input(hidden_states, input_dtype)
intermediate_cache1 = ops.moe_wna16_marlin_gemm( intermediate_cache1 = ops.moe_wna16_marlin_gemm(
hidden_states, gate_up_input,
intermediate_cache1, intermediate_cache1,
w1, w1,
bias1, bias1,
w1_scale, w1_scale,
a_scales1,
global_scale1, global_scale1,
w1_zeros, w1_zeros,
g_idx1, g_idx1,
...@@ -136,7 +143,7 @@ def _fused_marlin_moe( ...@@ -136,7 +143,7 @@ def _fused_marlin_moe(
size_n=2 * N, size_n=2 * N,
size_k=K, size_k=K,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
...@@ -151,12 +158,25 @@ def _fused_marlin_moe( ...@@ -151,12 +158,25 @@ def _fused_marlin_moe(
if expert_map is not None: if expert_map is not None:
output.zero_() output.zero_()
a_scales2 = None
if input_dtype == torch.int8:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
if input_global_scale2 is not None:
a_scales2 = a_scales2 * input_global_scale2
elif input_dtype == torch.float8_e4m3fn:
intermediate_cache2, a_scales2 = marlin_quant_input(
intermediate_cache2, input_dtype
)
output = ops.moe_wna16_marlin_gemm( output = ops.moe_wna16_marlin_gemm(
intermediate_cache2, intermediate_cache2,
output, output,
w2, w2,
bias2, bias2,
w2_scale, w2_scale,
a_scales2,
global_scale2, global_scale2,
w2_zeros, w2_zeros,
g_idx2, g_idx2,
...@@ -175,7 +195,7 @@ def _fused_marlin_moe( ...@@ -175,7 +195,7 @@ def _fused_marlin_moe(
size_n=K, size_n=K,
size_k=N, size_k=N,
is_k_full=is_k_full, is_k_full=is_k_full,
use_atomic_add=use_atomic_add, use_atomic_add=False,
use_fp32_reduce=True, use_fp32_reduce=True,
is_zp_float=False, is_zp_float=False,
) )
...@@ -203,6 +223,8 @@ def fused_marlin_moe( ...@@ -203,6 +223,8 @@ def fused_marlin_moe(
] = default_activation_func, ] = default_activation_func,
moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
input_global_scale1: torch.Tensor | None = None,
input_global_scale2: torch.Tensor | None = None,
global_scale1: torch.Tensor | None = None, global_scale1: torch.Tensor | None = None,
global_scale2: torch.Tensor | None = None, global_scale2: torch.Tensor | None = None,
g_idx1: torch.Tensor | None = None, g_idx1: torch.Tensor | None = None,
...@@ -216,6 +238,7 @@ def fused_marlin_moe( ...@@ -216,6 +238,7 @@ def fused_marlin_moe(
intermediate_cache2: torch.Tensor | None = None, intermediate_cache2: torch.Tensor | None = None,
is_k_full: bool = True, is_k_full: bool = True,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -287,6 +310,9 @@ def fused_marlin_moe( ...@@ -287,6 +310,9 @@ def fused_marlin_moe(
if M * topk / E / block_size_m < 0.9: if M * topk / E / block_size_m < 0.9:
break break
if input_dtype is not None and input_dtype.itemsize == 1:
block_size_m = max(block_size_m, 16)
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
...@@ -313,6 +339,8 @@ def fused_marlin_moe( ...@@ -313,6 +339,8 @@ def fused_marlin_moe(
num_tokens_post_padded=num_tokens_post_padded, num_tokens_post_padded=num_tokens_post_padded,
activation=activation, activation=activation,
activation_func=activation_func, activation_func=activation_func,
input_global_scale1=input_global_scale1,
input_global_scale2=input_global_scale2,
global_scale1=global_scale1, global_scale1=global_scale1,
global_scale2=global_scale2, global_scale2=global_scale2,
g_idx1=g_idx1, g_idx1=g_idx1,
...@@ -325,6 +353,7 @@ def fused_marlin_moe( ...@@ -325,6 +353,7 @@ def fused_marlin_moe(
intermediate_cache13=intermediate_cache13, intermediate_cache13=intermediate_cache13,
intermediate_cache2=intermediate_cache2, intermediate_cache2=intermediate_cache2,
output=None, output=None,
input_dtype=input_dtype,
is_k_full=is_k_full, is_k_full=is_k_full,
).view(-1, topk, K) ).view(-1, topk, K)
......
...@@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -266,7 +266,7 @@ class AutoRoundConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.awq_marlin import ( from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig, AWQMarlinConfig,
AWQMarlinLinearMethod, AWQMarlinLinearMethod,
AWQMoEMethod, AWQMarlinMoEMethod,
) )
quant_args_marlin = AWQMarlinConfig( quant_args_marlin = AWQMarlinConfig(
...@@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig): ...@@ -291,7 +291,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
if use_marlin: if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe_config) return AWQMarlinMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
config = { config = {
......
...@@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig): ...@@ -106,7 +106,7 @@ class AWQConfig(QuantizationConfig):
return AWQLinearMethod(self) return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import. # Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod from .awq_marlin import AWQMarlinConfig, AWQMarlinMoEMethod
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer from .utils.marlin_utils import check_moe_marlin_supports_layer
...@@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig): ...@@ -136,7 +136,7 @@ class AWQConfig(QuantizationConfig):
awq_marlin_config = AWQMarlinConfig.from_config( awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict marlin_compatible_config_dict
) )
return AWQMoEMethod(awq_marlin_config, layer.moe_config) return AWQMarlinMoEMethod(awq_marlin_config, layer.moe_config)
return None return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
......
...@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_marlin_supported,
check_marlin_supports_layer, check_marlin_supports_layer,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
...@@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -69,7 +71,6 @@ class AWQMarlinConfig(QuantizationConfig):
# num_bits -> type # num_bits -> type
TYPE_MAP = { TYPE_MAP = {
4: scalar_types.uint4, 4: scalar_types.uint4,
8: scalar_types.uint8,
} }
def __init__( def __init__(
...@@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -193,7 +194,9 @@ class AWQMarlinConfig(QuantizationConfig):
return AWQConfig.from_config(self.full_config).get_quant_method( return AWQConfig.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return AWQMarlinLinearMethod(self) quant_method = AWQMarlinLinearMethod(self)
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
...@@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -211,7 +214,9 @@ class AWQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method( return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return AWQMoEMethod(self, layer.moe_config) moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
return None return None
@classmethod @classmethod
...@@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -270,6 +275,8 @@ class AWQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQMarlinConfig) -> None: def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.quant_type = scalar_types.uint4
self.input_dtype = None
def create_weights( def create_weights(
self, self,
...@@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -312,6 +319,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
) )
num_groups = input_size_per_partition // group_size num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter( qzeros = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
...@@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -358,12 +366,19 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Allocate marlin workspace # Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device) layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format. # Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack( marlin_qweight = ops.awq_marlin_repack(
layer.qweight, layer.qweight,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qweight", marlin_qweight) replace_parameter(layer, "qweight", marlin_qweight)
...@@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -373,7 +388,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales) replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format. # Permute zero-points from AWQ format to marlin format.
...@@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -382,6 +406,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups, size_k=layer.num_groups,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits, num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "qzeros", marlin_zp) replace_parameter(layer, "qzeros", marlin_zp)
...@@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -409,11 +434,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type, quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias, bias=bias,
input_dtype=self.input_dtype,
) )
class AWQMoEMethod(FusedMoEMethodBase): class AWQMarlinMoEMethod(FusedMoEMethodBase):
def __init__( def __init__(
self, self,
quant_config: AWQMarlinConfig, quant_config: AWQMarlinConfig,
...@@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -422,8 +449,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.weight_bits != 4: if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.") raise ValueError("AWQMarlinMoEMethod only supports 4bit now.")
self.quant_type = scalar_types.uint4 self.quant_type = scalar_types.uint4
self.input_dtype = None
self.use_marlin = True self.use_marlin = True
def create_weights( def create_weights(
...@@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -435,6 +463,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
layer.input_dtype = self.input_dtype
extra_weight_attrs.update( extra_weight_attrs.update(
{ {
"is_transposed": True, "is_transposed": True,
...@@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -468,6 +497,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_groups_w13 = hidden_size // self.quant_config.group_size num_groups_w13 = hidden_size // self.quant_config.group_size
num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
# WEIGHT_SCALES # WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
...@@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -522,6 +553,21 @@ class AWQMoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0] num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device device = layer.w13_qweight.device
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(
layer.w13_qweight.view(-1, layer.w13_qweight.size(2)),
layer.w13_qzeros.view(-1, layer.w13_qzeros.size(2)),
inplace=True,
)
ops.marlin_int4_fp8_preprocess(
layer.w2_qweight.view(-1, layer.w2_qweight.size(2)),
layer.w2_qzeros.view(-1, layer.w2_qzeros.size(2)),
inplace=True,
)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
layer.w13_g_idx_sort_indices = torch.nn.Parameter( layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device), torch.empty((num_experts, 0), dtype=torch.int32, device=device),
...@@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -538,6 +584,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qweight.shape[1], size_k=layer.w13_qweight.shape[1],
size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
...@@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -547,6 +594,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qweight.shape[1], size_k=layer.w2_qweight.shape[1],
size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
...@@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -556,7 +604,16 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
...@@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -565,7 +622,17 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
marlin_w13_zp = moe_awq_to_marlin_zero_points( marlin_w13_zp = moe_awq_to_marlin_zero_points(
...@@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -573,6 +640,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w13_qzeros.shape[1], size_k=layer.w13_qzeros.shape[1],
size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qzeros", marlin_w13_zp) replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
...@@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -581,6 +649,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
size_k=layer.w2_qzeros.shape[1], size_k=layer.w2_qzeros.shape[1],
size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
num_bits=self.quant_config.weight_bits, num_bits=self.quant_config.weight_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qzeros", marlin_w2_zp) replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
...@@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -636,6 +705,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -643,4 +714,5 @@ class AWQMoEMethod(FusedMoEMethodBase):
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.input_dtype,
) )
...@@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -157,7 +157,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, Attention): if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self) return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix) return CompressedTensorsMoEMethod.get_moe_method(
self, layer, layer_name=prefix
)
return None return None
def _add_fused_moe_to_target_scheme_map(self): def _add_fused_moe_to_target_scheme_map(self):
...@@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -547,6 +549,7 @@ class CompressedTensorsConfig(QuantizationConfig):
weight_quant: QuantizationArgs, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs, input_quant: QuantizationArgs,
format: str | None = None, format: str | None = None,
layer_name: str | None = None,
) -> "CompressedTensorsScheme": ) -> "CompressedTensorsScheme":
# use the per-layer format if defined, otherwise, use global format # use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format format = format if format is not None else self.quant_format
...@@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -585,6 +588,7 @@ class CompressedTensorsConfig(QuantizationConfig):
symmetric=weight_quant.symmetric, symmetric=weight_quant.symmetric,
group_size=weight_quant.group_size, group_size=weight_quant.group_size,
actorder=weight_quant.actorder, actorder=weight_quant.actorder,
layer_name=layer_name,
) )
act_quant_format = is_activation_quantization_format(format) act_quant_format = is_activation_quantization_format(format)
...@@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -724,7 +728,10 @@ class CompressedTensorsConfig(QuantizationConfig):
else: else:
# Find the quant_scheme # Find the quant_scheme
scheme = self._get_scheme_from_parts( # type: ignore scheme = self._get_scheme_from_parts( # type: ignore
weight_quant=weight_quant, input_quant=input_quant, format=format weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_name=layer_name,
) )
# Raise error if device does not support the scheme # Raise error if device does not support the scheme
......
...@@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -64,6 +64,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
) )
...@@ -101,7 +103,7 @@ __all__ = [ ...@@ -101,7 +103,7 @@ __all__ = [
"CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4Nvfp4MoeMethod", "CompressedTensorsW4A4Nvfp4MoEMethod",
"CompressedTensorsW4A8Int8MoEMethod", "CompressedTensorsW4A8Int8MoEMethod",
] ]
...@@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -111,13 +113,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def get_moe_method( def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, layer_name: str,
) -> "CompressedTensorsMoEMethod": ) -> "CompressedTensorsMoEMethod":
# FusedMoE was made by combining multiple Linears so need to # FusedMoE was made by combining multiple Linears so need to
# make sure quantization config for Linear can target it # make sure quantization config for Linear can target it
quant_config._add_fused_moe_to_target_scheme_map() quant_config._add_fused_moe_to_target_scheme_map()
unfused_names = [ unfused_names = [
prefix + proj_name layer_name + proj_name
for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
] ]
# TODO: refactor this to use expert_mapping and check all layer numbers # TODO: refactor this to use expert_mapping and check all layer numbers
...@@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -158,32 +160,40 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic." "WNA16MoE is not supported with actorder=group/dynamic."
) )
logger.info_once("Using CompressedTensorsWNA16MoEMethod") logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config) return CompressedTensorsWNA16MoEMethod(
quant_config, layer.moe_config, layer_name
)
else: else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod( return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config quant_config, layer.moe_config, layer_name
) )
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoeMethod(layer.moe_config) return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
elif ( elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant) or quant_config._is_fp8_w8a8(weight_quant, input_quant)
): ):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW8A8Fp8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW8A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config) return CompressedTensorsW4A8Int8MoEMethod(
quant_config, layer.moe_config, layer_name
)
else: else:
raise RuntimeError( raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
) )
class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support, detect_nvfp4_moe_support,
) )
...@@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ...@@ -194,17 +204,21 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
self.flashinfer_moe_backend = None self.flashinfer_moe_backend = None
if self.allow_flashinfer: if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
logger.info_once( logger.info_once(
f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
" for CompressedTensorsW4A4Nvfp4MoeMethod." " for CompressedTensorsW4A4Nvfp4MoEMethod."
) )
elif self.use_marlin: elif self.use_marlin:
logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoeMethod.") logger.info_once("Using Marlin for CompressedTensorsW4A4Nvfp4MoEMethod.")
else: else:
logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoeMethod.") logger.info_once("Using Cutlass for CompressedTensorsW4A4Nvfp4MoEMethod.")
def create_weights( def create_weights(
self, self,
...@@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ...@@ -354,7 +368,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
) )
if self.use_marlin: if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
return return
# w13 # w13
if ( if (
...@@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ...@@ -538,7 +552,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
): ):
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." "EPLB not supported for `CompressedTensorsW4A4MoEMethod` yet."
) )
return flashinfer_trtllm_fp4_moe( return flashinfer_trtllm_fp4_moe(
...@@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ...@@ -576,6 +590,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
...@@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod): ...@@ -610,7 +625,7 @@ class CompressedTensorsW4A4Nvfp4MoeMethod(CompressedTensorsMoEMethod):
assert expert_map is None, ( assert expert_map is None, (
"Expert Parallelism / expert_map " "Expert Parallelism / expert_map "
"is currently not supported for " "is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoeMethod." "CompressedTensorsW4A4Nvfp4MoEMethod."
) )
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
...@@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -637,6 +652,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
...@@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -690,6 +706,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
or self.is_fp8_w8a8_sm100 or self.is_fp8_w8a8_sm100
) )
self.disable_expert_map = False self.disable_expert_map = False
self.layer_name = layer_name
self.marlin_input_dtype = (
get_marlin_input_dtype(layer_name) if self.use_marlin else None
)
def create_weights( def create_weights(
self, self,
...@@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -931,7 +951,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin: elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
...@@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1144,6 +1166,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
...@@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1240,6 +1263,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1392,6 +1416,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1403,6 +1428,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
self.strategy = config.strategy self.strategy = config.strategy
self.group_size = config.group_size self.group_size = config.group_size
self.actorder = config.actorder self.actorder = config.actorder
self.layer_name = layer_name
self.marlin_input_dtype = get_marlin_input_dtype(layer_name)
assert config.symmetric, "Only symmetric quantization is supported for MoE" assert config.symmetric, "Only symmetric quantization is supported for MoE"
if not ( if not (
...@@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1477,6 +1504,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_groups_w2 = w2_scales_size // self.group_size num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size num_groups_w13 = hidden_size // self.group_size
layer.num_groups_w13 = num_groups_w13
layer.num_groups_w2 = num_groups_w2
w13_scale = torch.nn.Parameter( w13_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
...@@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1560,6 +1590,17 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_weight_g_idx.shape[0] num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device device = layer.w13_weight_g_idx.device
is_a_8bit = (
self.marlin_input_dtype is not None
and self.marlin_input_dtype.itemsize == 1
)
if self.marlin_input_dtype == torch.float8_e4m3fn:
# NOTE: for non-zp quantization format only
ops.marlin_int4_fp8_preprocess(layer.w13_weight_packed, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_weight_packed, inplace=True)
layer.w13_weight_scale.data = layer.w13_weight_scale.data * 512
layer.w2_weight_scale.data = layer.w2_weight_scale.data * 512
# when running models with grouped act order, # when running models with grouped act order,
# resort to g_idx values provided in checkpoint # resort to g_idx values provided in checkpoint
...@@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1610,31 +1651,54 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight_packed.shape[1] * self.packed_factor, layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2], layer.w13_weight_packed.shape[2],
self.num_bits, self.num_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed, layer.w2_weight_packed,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor, layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2], layer.w2_weight_packed.shape[2],
self.num_bits, self.num_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_weight_scale, s=layer.w13_weight_scale,
size_k=layer.w13_weight_packed.shape[2], size_k=layer.w13_weight_packed.shape[2],
size_n=layer.w13_weight_scale.shape[2], size_n=layer.w13_weight_scale.shape[2],
group_size=self.group_size, group_size=self.group_size,
is_a_8bit=is_a_8bit,
) )
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_weight_scale, s=layer.w2_weight_scale,
size_k=layer.w2_weight_scale.shape[1] size_k=layer.w2_weight_scale.shape[1]
* (self.group_size if self.group_size != -1 else self.packed_factor), * (self.group_size if self.group_size != -1 else self.packed_factor),
size_n=layer.w2_weight_scale.shape[2], size_n=layer.w2_weight_scale.shape[2],
group_size=self.group_size, group_size=self.group_size,
is_a_8bit=is_a_8bit,
) )
if self.marlin_input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
...@@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1729,6 +1793,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1738,6 +1804,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
sort_indices1=layer.w13_g_idx_sort_indices, sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
) )
...@@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1747,6 +1814,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1999,6 +2067,7 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
self, self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig, moe: FusedMoEConfig,
layer_name: str | None = None,
): ):
super().__init__(moe) super().__init__(moe)
self.has_bias = self.moe.has_bias self.has_bias = self.moe.has_bias
......
...@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( ...@@ -14,7 +14,11 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, MPLinearLayerConfig,
choose_mp_linear_kernel, choose_mp_linear_kernel,
) )
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
marlin_repeat_scales_on_all_ranks, marlin_repeat_scales_on_all_ranks,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -45,12 +49,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size: int | None = None, group_size: int | None = None,
symmetric: bool | None = True, symmetric: bool | None = True,
actorder: ActivationOrdering | None = None, actorder: ActivationOrdering | None = None,
layer_name: str | None = None,
): ):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.symmetric = symmetric self.symmetric = symmetric
self.group_size = -1 if group_size is None else group_size self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP self.has_g_idx = actorder == ActivationOrdering.GROUP
self.layer_name = layer_name
if self.group_size == -1 and self.strategy != "channel": if self.group_size == -1 and self.strategy != "channel":
raise ValueError( raise ValueError(
...@@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -108,6 +114,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__) logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__)
if isinstance(kernel_type, MarlinLinearKernel):
input_dtype = get_marlin_input_dtype(self.layer_name)
if input_dtype is not None:
mp_linear_kernel_config.act_type = input_dtype
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = input_size != input_size_per_partition row_parallel = input_size != input_size_per_partition
......
...@@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
...@@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig): ...@@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return Fp8LinearMethod(self) quant_method = Fp8LinearMethod(self)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
if is_layer_skipped( if is_layer_skipped(
prefix=prefix, prefix=prefix,
...@@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig): ...@@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping, fused_mapping=self.packed_modules_mapping,
): ):
return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer) moe_quant_method = Fp8MoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self) return Fp8KVCacheMethod(self)
return None return None
...@@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = ( self.use_marlin = (
not current_platform.has_device_capability(89) not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN or envs.VLLM_TEST_FORCE_FP8_MARLIN
...@@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
) )
if self.use_marlin: if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first) prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale del layer.input_scale
return return
...@@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
workspace=layer.workspace, workspace=layer.workspace,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias, bias=bias,
) )
...@@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant, layer.moe_parallel_config self.block_quant, layer.moe_parallel_config
) )
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
...@@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = w13_weight.data layer.w13_weight.data = w13_weight.data
if self.use_marlin: if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False) prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.w13_input_scale del layer.w13_input_scale
del layer.w2_input_scale del layer.w2_input_scale
...@@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace, workspace=layer.workspace,
) )
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
......
...@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import ( ...@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_marlin_supported,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales, marlin_moe_permute_scales,
marlin_permute_bias, marlin_permute_bias,
...@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
return MoeWNA16Config.from_config(self.full_config).get_quant_method( return MoeWNA16Config.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod) moe_quant_method = get_moe_quant_method(
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) self, layer, prefix, GPTQMarlinMoEMethod
)
if moe_quant_method is None:
return None
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
quant_method = get_linear_quant_method(
self, layer, prefix, GPTQMarlinLinearMethod
)
if quant_method is None:
return None
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
@classmethod @classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
...@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.input_dtype = None
self.quant_type = self.quant_config.quant_type
# Verify supported on platform. # Verify supported on platform.
verify_marlin_supported( verify_marlin_supported(
...@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
input_dtype = self.input_dtype
mp_linear_kernel_config = MPLinearLayerConfig( mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size), full_weight_shape=(input_size, output_size),
...@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition, output_size_per_partition,
), ),
weight_type=self.quant_config.quant_type, weight_type=self.quant_config.quant_type,
act_type=params_dtype, act_type=params_dtype if input_dtype is None else input_dtype,
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
zero_points=False, zero_points=False,
has_g_idx=self.quant_config.desc_act, has_g_idx=self.quant_config.desc_act,
...@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_type = scalar_types.uint8b128 self.quant_type = scalar_types.uint8b128
else: else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True self.use_marlin = True
def create_weights( def create_weights(
...@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
layer.input_dtype = self.input_dtype
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
self.is_k_full = (not self.quant_config.desc_act) or ( self.is_k_full = (not self.quant_config.desc_act) or (
...@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
scales_size2 = 1 scales_size2 = 1
strategy = FusedMoeWeightScaleSupported.CHANNEL.value strategy = FusedMoeWeightScaleSupported.CHANNEL.value
layer.num_groups_w13 = scales_size13
layer.num_groups_w2 = scales_size2
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True}) extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_qweight = torch.nn.Parameter( w13_qweight = torch.nn.Parameter(
...@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order # Process act_order
if self.quant_config.desc_act: if self.quant_config.desc_act:
# Get sorting based on g_idx # Get sorting based on g_idx
...@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2], layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
...@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2], layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales # Repack scales
...@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_k=layer.intermediate_size_per_partition, size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales, s=layer.w2_scales,
...@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
), ),
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
) )
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None: if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
...@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits, router_logits,
topk_weights, topk_weights,
topk_ids, topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id, quant_type_id=self.quant_type.id,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices2=layer.w2_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
) )
...@@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase): ...@@ -351,6 +351,7 @@ class HQQMarlinMethod(LinearMethodBase):
bias, bias,
scales, scales,
None, None,
None,
zeros, zeros,
layer.g_idx, layer.g_idx,
layer.g_idx_sort_indices, layer.g_idx_sort_indices,
......
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear, apply_gptq_marlin_linear,
check_marlin_supports_shape, check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_make_workspace_new,
...@@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
) )
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
...@@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device device = getattr(layer, self.w_q_name).device
c = self.config c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)
if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
...@@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0], size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits, num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
) )
return x return x
...@@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=c.partition_weight_shape[0], size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
group_size=c.group_size, group_size=c.group_size,
is_a_8bit=is_a_8bit,
) )
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x return x
if c.has_g_idx: if c.has_g_idx:
...@@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
size_k=grouped_k, size_k=grouped_k,
size_n=c.partition_weight_shape[1], size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits, num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
), ),
) )
else: else:
...@@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):
# `process_weights_after_loading` will ensure w_zp and w_gidx are not # `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin # None for marlin
return apply_gptq_marlin_linear( return apply_gptq_marlin_linear(
input=x, input=x,
weight=w_q, weight=w_q,
...@@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel): ...@@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
input_size_per_partition=c.partition_weight_shape[0], input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1], output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full, is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias, bias=bias,
input_dtype=c.act_type,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment