Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
#pragma once
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
* Gemm shape.
*/
namespace vllm {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
// M in (256, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
// M in (64, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in (16, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M16 {
// M in [1, 16]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
/**
 * Picks an SM100 fp8 GEMM configuration from M = a.size(0) and runs it via
 * cutlass_gemm_caller.
 *
 * Buckets:
 *   M <= 16    -> sm100_fp8_config_M16
 *   M <= 64    -> sm100_fp8_config_M64
 *   M <= 256   -> sm100_fp8_config_M256
 *   otherwise  -> sm100_fp8_config_default
 *
 * @param out  Output tensor, forwarded unchanged to cutlass_gemm_caller.
 * @param a    First operand; must have dtype float8_e4m3fn. Its dim-0 size
 *             is the M used for bucket selection.
 * @param b    Second operand; must have dtype float8_e4m3fn.
 * @param args Extra epilogue arguments, perfectly forwarded to the caller.
 *
 * Fails a TORCH_CHECK if `a` or `b` has a dtype other than float8_e4m3fn.
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "a must have dtype float8_e4m3fn");
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn,
              "b must have dtype float8_e4m3fn");

  using Cutlass3xGemmDefault =
      typename sm100_fp8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM16 =
      typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;

  // Keep M at the 64-bit width returned by Tensor::size(). Every bucket bound
  // is a power of two, so comparing M directly selects the same bucket as
  // rounding M up to the next power of two first, without narrowing M to
  // 32 bits (which could wrap for very large M and pick the wrong kernel).
  int64_t const m = a.size(0);

  if (m <= 16) {
    // m in [1, 16]
    return cutlass_gemm_caller<Cutlass3xGemmM16>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (m <= 64) {
    // m in (16, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (m <= 256) {
    // m in (64, 256]
    return cutlass_gemm_caller<Cutlass3xGemmM256>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (256, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
/**
 * SM100 fp8 scaled-mm entry point parameterized on the epilogue.
 *
 * Checks that `a` and `b` are float8_e4m3fn, then chooses the CUTLASS output
 * element type from out.dtype(): bfloat16 maps to cutlass::bfloat16_t,
 * float16 maps to cutlass::half_t. Any other output dtype fails a
 * TORCH_CHECK. All epilogue arguments are perfectly forwarded to
 * cutlass_gemm_sm100_fp8_dispatch.
 */
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using ElementIn = cutlass::float_e4m3_t;

  bool const wants_bf16 = out.dtype() == torch::kBFloat16;
  if (wants_bf16) {
    cutlass_gemm_sm100_fp8_dispatch<ElementIn, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // The only other supported output type is fp16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm100_fp8_dispatch<ElementIn, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} // namespace vllm
\ No newline at end of file
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include "core/registration.h"
#include <torch/all.h> #include <torch/all.h>
#include <cutlass/arch/arch.h> #include <cutlass/arch/arch.h>
...@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm( ...@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
"12.8 or above."); "12.8 or above.");
#endif #endif
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
}
...@@ -31,6 +31,13 @@ ...@@ -31,6 +31,13 @@
namespace vllm { namespace vllm {
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return (x + y - 1) / y * y;
}
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
...@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched."); "Vec size is not matched.");
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
// Each thread writes 4 uint32_t elements.
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
col += blockDim.x * 4) {
SFout[row * sf_n_int + col] = 0x00;
}
}
// Get the global scaling factor, which will be applied to the SF. // Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is // Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)). // (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops. // Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
...@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx, colIdx, numCols, SFout); rowIdx, colIdx, numCols, SFout);
out_pos = out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
} }
} }
} }
......
...@@ -159,7 +159,11 @@ void rms_norm_dynamic_per_token_quant( ...@@ -159,7 +159,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
} }
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh" #include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
// #include "quantization/fp8/common.cuh" // #include "quantization/w8a8/fp8/common.cuh"
namespace vllm { namespace vllm {
......
...@@ -189,7 +189,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, ...@@ -189,7 +189,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
const uint32_t*, const half*, const uint32_t*, const half*,
half*, const int, const int, half*, const int, const int,
const int, const int, const int, const int,
const int*); const bool, const int*);
template <bool first_block, int m_count> template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_4bit_kernel( __global__ void gemm_half_q_half_gptq_4bit_kernel(
...@@ -197,12 +197,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( ...@@ -197,12 +197,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
...@@ -260,10 +263,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( ...@@ -260,10 +263,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
// Column result // Column result
float block_c[m_count][4] = {}; float block_c[m_count][4] = {};
...@@ -276,10 +279,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( ...@@ -276,10 +279,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
} }
#pragma unroll #pragma unroll
...@@ -333,12 +336,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( ...@@ -333,12 +336,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
...@@ -413,10 +419,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( ...@@ -413,10 +419,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
int4 load_int4 = *b_ptr4; int4 load_int4 = *b_ptr4;
half2 dq[4][8]; half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
...@@ -452,12 +458,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( ...@@ -452,12 +458,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
...@@ -538,13 +547,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( ...@@ -538,13 +547,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
half2 dq[4][16]; half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1); size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1); size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1); size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1); size_n, zeros[3] + zero_offset);
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
...@@ -578,12 +587,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( ...@@ -578,12 +587,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
...@@ -662,13 +674,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( ...@@ -662,13 +674,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
half2 dq[4][4]; half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1); zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1); zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1); zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1); zeros[3] + zero_offset);
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
block_c[m][0] = block_c[m][0] =
...@@ -734,7 +746,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, ...@@ -734,7 +746,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm, const half* b_gptq_scales, const int* b_q_perm,
half* c, int size_m, int size_n, int size_k, half* c, int size_m, int size_n, int size_k,
int m_count, int groups, int bit) { int m_count, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
...@@ -747,20 +760,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, ...@@ -747,20 +760,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros, kernel<<<gridDim, blockDim, 0, stream>>>(
b_gptq_scales, c, size_m, size_n, a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
size_k, groups, b_q_perm); groups, use_v2_format, b_q_perm);
} }
__global__ void reconstruct_exllama_8bit_kernel( __global__ void reconstruct_exllama_8bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
...@@ -816,13 +832,13 @@ __global__ void reconstruct_exllama_8bit_kernel( ...@@ -816,13 +832,13 @@ __global__ void reconstruct_exllama_8bit_kernel(
half2 dq[4][4]; half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1); zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1); zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1); zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1); zeros[3] + zero_offset);
// half* dqh = (half*)dq; // half* dqh = (half*)dq;
if (b_q_perm) { if (b_q_perm) {
...@@ -853,11 +869,14 @@ __global__ void reconstruct_exllama_4bit_kernel( ...@@ -853,11 +869,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
...@@ -892,10 +911,10 @@ __global__ void reconstruct_exllama_4bit_kernel( ...@@ -892,10 +911,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
__syncthreads(); __syncthreads();
...@@ -908,10 +927,10 @@ __global__ void reconstruct_exllama_4bit_kernel( ...@@ -908,10 +927,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
} }
for (int p = 0; p < 4; p++) { for (int p = 0; p < 4; p++) {
...@@ -958,11 +977,14 @@ __global__ void reconstruct_exllama_3bit_kernel( ...@@ -958,11 +977,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
...@@ -1020,13 +1042,13 @@ __global__ void reconstruct_exllama_3bit_kernel( ...@@ -1020,13 +1042,13 @@ __global__ void reconstruct_exllama_3bit_kernel(
half2 dq[4][16]; half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1); size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1); size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1); size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1); size_n, zeros[3] + zero_offset);
if (b_q_perm) { if (b_q_perm) {
for (int j = 0; j < 16; j++) { for (int j = 0; j < 16; j++) {
...@@ -1056,11 +1078,14 @@ __global__ void reconstruct_exllama_2bit_kernel( ...@@ -1056,11 +1078,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point, v2 uses it as stored.
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
...@@ -1112,10 +1137,10 @@ __global__ void reconstruct_exllama_2bit_kernel( ...@@ -1112,10 +1137,10 @@ __global__ void reconstruct_exllama_2bit_kernel(
int4 load_int4 = *b_ptr4; int4 load_int4 = *b_ptr4;
half2 dq[4][8]; half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
b_ptr += size_n; b_ptr += size_n;
// half* dqh = (half*)dq; // half* dqh = (half*)dq;
...@@ -1147,7 +1172,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight, ...@@ -1147,7 +1172,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm, const half* b_gptq_scales, const int* b_q_perm,
half* out, int height, int width, int groups, half* out, int height, int width, int groups,
int bit) { bool use_v2_format, int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
...@@ -1166,14 +1191,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight, ...@@ -1166,14 +1191,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>( reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
out); use_v2_format, out);
} }
__global__ void gemm_half_q_half_alt_4bit_kernel( __global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat, const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales, half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) { int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 8; int zero_width = width / 8;
int vec_height = height * 4; int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2; const int blockwidth2 = BLOCK_KN_SIZE / 2;
...@@ -1183,6 +1208,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( ...@@ -1183,6 +1208,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point (offset 1), while v2 uses the stored value as-is (offset 0).
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) { if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) { for (int m = 0; m < b_end; ++m) {
...@@ -1227,10 +1255,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( ...@@ -1227,10 +1255,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
half2 zero = __halves2half2( half2 zero = __halves2half2(
__hmul(scale_f, __hmul(scale_f,
__int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
1)), zero_offset)),
__hmul(scale_f2, __hmul(
__int2half_rn( scale_f2,
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) -
zero_offset)));
scales_tmp[tmp_k] = scale; scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero; zeros_tmp[tmp_k] = zero;
} }
...@@ -1272,7 +1301,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( ...@@ -1272,7 +1301,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat, const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales, half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) { int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 4; int zero_width = width / 4;
int vec_height = height * 2; int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2; const int blockwidth2 = BLOCK_KN_SIZE / 2;
...@@ -1282,6 +1311,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( ...@@ -1282,6 +1311,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point (offset 1), while v2 uses the stored value as-is (offset 0).
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) { if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) { for (int m = 0; m < b_end; ++m) {
...@@ -1316,12 +1348,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( ...@@ -1316,12 +1348,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
half scale_f2 = scales[g2 * width + w]; half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2); half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2( half2 zero = __halves2half2(
__hmul(scale_f, __hmul(scale_f, __int2half_rn(
__int2half_rn( -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) -
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), zero_offset)),
__hmul(scale_f2, __hmul(
__int2half_rn( scale_f2,
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)));
scales_tmp[tmp_k] = scale; scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero; zeros_tmp[tmp_k] = zero;
} }
...@@ -1359,7 +1392,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, ...@@ -1359,7 +1392,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, const half* b_gptq_scales, const int* b_g_idx,
half* c, int size_m, int size_n, int size_k, half* c, int size_m, int size_n, int size_k,
int bit) { bool use_v2_format, int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
...@@ -1376,17 +1409,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, ...@@ -1376,17 +1409,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>( kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n); size_m, size_k / 32 * bit, size_n, use_v2_format);
} }
template <class T, int bit> template <class T, int bit>
__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, __global__ void reconstruct_gptq_kernel(
const half* __restrict__ w_scales, const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int* __restrict__ g_idx, const int height, const int width, const int group,
const int height, const int width, const bool use_v2_format, half* __restrict__ out) {
const int group,
half* __restrict__ out) {
// Start of block // Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
...@@ -1399,6 +1430,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, ...@@ -1399,6 +1430,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
MatrixView_half w_scales_(w_scales, group, width); MatrixView_half w_scales_(w_scales, group, width);
T w_zeros_(w_zeros, group, width); T w_zeros_(w_zeros, group, width);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point (offset 1), while v2 uses the stored value as-is (offset 0).
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w_read = w[blockIdx.y * width + column]; uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column); half* out_ptr = out_.item_ptr(row, column);
...@@ -1406,7 +1440,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, ...@@ -1406,7 +1440,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
for (int s = 0; s < 32; s += bit) { for (int s = 0; s < 32; s += bit) {
int group = g_idx[row + s / bit]; int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column); half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
half w_item = half w_item =
__hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
w_scale); w_scale);
...@@ -1419,7 +1453,7 @@ __global__ void reconstruct_gptq_3bit_kernel( ...@@ -1419,7 +1453,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales, const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group, const int height, const int width, const int group,
half* __restrict__ out) { const bool use_v2_format, half* __restrict__ out) {
// Start of block // Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32; auto row = blockIdx.y * 32;
...@@ -1431,6 +1465,9 @@ __global__ void reconstruct_gptq_3bit_kernel( ...@@ -1431,6 +1465,9 @@ __global__ void reconstruct_gptq_3bit_kernel(
MatrixView_half w_scales_(w_scales, group, width); MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width); MatrixView_q3_row w_zeros_(w_zeros, group, width);
// GPTQv1 and GPTQv2 handle zero points differently: v1 adds 1 to the stored
// zero point (offset 1), while v2 uses the stored value as-is (offset 0).
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w1 = w[(blockIdx.y * 3) * width + column]; uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
...@@ -1440,7 +1477,7 @@ __global__ void reconstruct_gptq_3bit_kernel( ...@@ -1440,7 +1477,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
for (int i = 0; i < 32; i += 1) { for (int i = 0; i < 32; i += 1) {
int group = g_idx[row + i]; int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column); half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
int w_item; int w_item;
if (i == 10) { if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4); w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
...@@ -1460,7 +1497,8 @@ __global__ void reconstruct_gptq_3bit_kernel( ...@@ -1460,7 +1497,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, half* out, const half* b_gptq_scales, const int* b_g_idx, half* out,
int height, int width, int groups, int bit) { int height, int width, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
...@@ -1480,7 +1518,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, ...@@ -1480,7 +1518,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales, kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
b_gptq_qzeros, b_g_idx, height, b_gptq_qzeros, b_g_idx, height,
width, groups, out); width, groups, use_v2_format, out);
} }
void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
...@@ -1488,7 +1526,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, ...@@ -1488,7 +1526,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, const half* b_gptq_scales, const int* b_g_idx,
half* c, half* temp_dq, int size_m, int size_n, half* c, half* temp_dq, int size_m, int size_n,
int size_k, int groups, bool use_exllama, int bit) { int size_k, int groups, bool use_exllama,
bool use_v2_format, int bit) {
bool use_reconstruct; bool use_reconstruct;
if (use_exllama) { if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
...@@ -1502,10 +1541,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, ...@@ -1502,10 +1541,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
// Reconstruct FP16 matrix, then cuBLAS // Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) { if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit); temp_dq, size_k, size_n, groups, use_v2_format, bit);
} else { } else {
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit); temp_dq, size_k, size_n, groups, use_v2_format, bit);
} }
const half alpha = __float2half(1.0f); const half alpha = __float2half(1.0f);
...@@ -1521,18 +1560,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, ...@@ -1521,18 +1560,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
if (max_chunks) { if (max_chunks) {
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c, last_chunk, size_n, size_k, b_g_idx, c, last_chunk, size_n, size_k,
BLOCK_M_SIZE_MAX, groups, bit); BLOCK_M_SIZE_MAX, groups, use_v2_format, bit);
} }
if (last_chunk_size) { if (last_chunk_size) {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, gemm_half_q_half_cuda_part(
b_gptq_qzeros, b_gptq_scales, b_g_idx, a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales,
c + last_chunk * size_n, last_chunk_size, b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k,
size_n, size_k, last_chunk_size, groups, bit); last_chunk_size, groups, use_v2_format, bit);
} }
} else { } else {
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k, bit); c, size_m, size_n, size_k, use_v2_format, bit);
} }
} }
...@@ -1819,7 +1858,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, ...@@ -1819,7 +1858,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit) { bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
...@@ -1837,7 +1876,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -1837,7 +1876,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
c.size(1), // n c.size(1), // n
a.size(1), // k a.size(1), // k
b_gptq_qzeros.size(0), // group number b_gptq_qzeros.size(0), // group number
use_exllama, bit); use_exllama, use_v2_format, bit);
return c; return c;
} }
......
...@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, ...@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out; return out;
} }
// Shape-only variant of awq_marlin_repack: performs no repacking and only
// allocates an output tensor of the resulting (symbolic) shape.
//
// b_q_weight : supplies the dtype and device of the returned tensor.
// size_k     : symbolic K extent; divided by marlin::tile_size for dim 0.
// size_n     : symbolic N extent; scaled by marlin::tile_size and divided by
//              the pack factor for dim 1.
// num_bits   : bit width of one quantized value; 32 / num_bits values are
//              packed per 32-bit word.
//
// Returns an empty tensor of shape
//   {size_k / tile_size, size_n * tile_size / (32 / num_bits)}.
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  // Number of num_bits-wide values that fit into one 32-bit word.
  const int values_per_word = 32 / num_bits;

  const c10::SymInt out_rows = size_k / marlin::tile_size;
  const c10::SymInt out_cols = size_n * marlin::tile_size / values_per_word;

  // The result mirrors the input's dtype and device.
  const auto out_options = torch::TensorOptions()
                               .dtype(b_q_weight.dtype())
                               .device(b_q_weight.device());

  return torch::empty_symint({out_rows, out_cols}, out_options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack); m.impl("awq_marlin_repack", &awq_marlin_repack);
} }
// Registers awq_marlin_repack_meta as the implementation of the
// "awq_marlin_repack" op for the Meta dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
...@@ -17,28 +17,32 @@ FILE_HEAD = """ ...@@ -17,28 +17,32 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME { namespace MARLIN_NAMESPACE_NAME {
""".strip() """.strip()
TEMPLATE = ("template __global__ void Marlin<" TEMPLATE = (
"{{scalar_t}}, " "template __global__ void Marlin<"
"{{w_type_id}}, " "{{scalar_t}}, "
"{{s_type_id}}, " "{{w_type_id}}, "
"{{threads}}, " "{{s_type_id}}, "
"{{thread_m_blocks}}, " "{{threads}}, "
"{{thread_n_blocks}}, " "{{thread_m_blocks}}, "
"{{thread_k_blocks}}, " "{{thread_n_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, " "{{thread_k_blocks}}, "
"{{stages}}, " "{{'true' if m_block_size_8 else 'false'}}, "
"{{group_blocks}}, " "{{stages}}, "
"{{'true' if is_zp_float else 'false'}}>" "{{group_blocks}}, "
"( MARLIN_KERNEL_PARAMS );") "{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported, # int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = [ SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", "vllm::kU4",
"vllm::kFE2M1f" "vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
] ]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
(128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks: # group_blocks:
...@@ -59,11 +63,12 @@ def generate_new_kernels(): ...@@ -59,11 +63,12 @@ def generate_new_kernels():
all_template_str_list = [] all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product( for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8 # act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [ if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128" "vllm::kU4B8",
"vllm::kU8B128",
]: ]:
continue continue
if thread_configs[2] == 256: if thread_configs[2] == 256:
...@@ -93,8 +98,7 @@ def generate_new_kernels(): ...@@ -93,8 +98,7 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False] is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and \ if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
group_blocks == 4:
# HQQ (is_zp_float = true) only supports # HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16 # 4bit quantization and fp16
is_zp_float_list.append(True) is_zp_float_list.append(True)
......
...@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out; return out;
} }
// Shape-only variant of gptq_marlin_repack: performs no repacking and only
// allocates an output tensor of the resulting (symbolic) shape.
//
// b_q_weight : supplies the dtype and device of the returned tensor.
// perm       : accepted for signature parity; not read by this function.
// size_k     : symbolic K extent; divided by marlin::tile_size for dim 0.
// size_n     : symbolic N extent; scaled by marlin::tile_size and divided by
//              the pack factor for dim 1.
// num_bits   : bit width of one quantized value; 32 / num_bits values are
//              packed per 32-bit word.
//
// Returns an empty tensor of shape
//   {size_k / tile_size, size_n * tile_size / (32 / num_bits)}.
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  // Number of num_bits-wide values that fit into one 32-bit word.
  const int values_per_word = 32 / num_bits;

  const c10::SymInt out_rows = size_k / marlin::tile_size;
  const c10::SymInt out_cols = size_n * marlin::tile_size / values_per_word;

  // The result mirrors the input's dtype and device.
  const auto out_options = torch::TensorOptions()
                               .dtype(b_q_weight.dtype())
                               .device(b_q_weight.device());

  return torch::empty_symint({out_rows, out_cols}, out_options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack); m.impl("gptq_marlin_repack", &gptq_marlin_repack);
} }
// Registers gptq_marlin_repack_meta as the implementation of the
// "gptq_marlin_repack" op for the Meta dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
...@@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) { ...@@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) {
}); });
if (numel % 256 != 0) { if (numel % 256 != 0) {
out = out.index({torch::indexing::Slice(0, numel / had_size)}); out = out.narrow(0, 0, numel / had_size);
} }
if (inplace && out.data_ptr() != x.data_ptr()) { if (inplace && out.data_ptr() != x.data_ptr()) {
......
...@@ -9,23 +9,23 @@ from collections.abc import Iterable ...@@ -9,23 +9,23 @@ from collections.abc import Iterable
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import reduce from functools import reduce
from typing import Optional, Union
import jinja2 import jinja2
# yapf conflicts with isort for this block from vllm_cutlass_library_extension import (
# yapf: disable DataType,
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, EpilogueScheduleTag,
EpilogueScheduleType, EpilogueScheduleType,
MixedInputKernelScheduleType, MixedInputKernelScheduleType,
TileSchedulerTag, TileSchedulerTag,
TileSchedulerType, VLLMDataType, TileSchedulerType,
VLLMDataTypeNames, VLLMDataType,
VLLMDataTypeSize, VLLMDataTypeTag, VLLMDataTypeNames,
VLLMDataTypeTorchDataTypeTag, VLLMDataTypeSize,
VLLMDataTypeVLLMScalarTypeTag, VLLMDataTypeTag,
VLLMKernelScheduleTag) VLLMDataTypeTorchDataTypeTag,
VLLMDataTypeVLLMScalarTypeTag,
# yapf: enable VLLMKernelScheduleTag,
)
# #
# Generator templating # Generator templating
...@@ -258,7 +258,7 @@ class ScheduleConfig: ...@@ -258,7 +258,7 @@ class ScheduleConfig:
@dataclass(frozen=True) @dataclass(frozen=True)
class TypeConfig: class TypeConfig:
a: DataType a: DataType
b: Union[DataType, VLLMDataType] b: DataType | VLLMDataType
b_group_scale: DataType b_group_scale: DataType
b_group_zeropoint: DataType b_group_zeropoint: DataType
b_channel_scale: DataType b_channel_scale: DataType
...@@ -279,25 +279,30 @@ class PrepackTypeConfig: ...@@ -279,25 +279,30 @@ class PrepackTypeConfig:
class ImplConfig: class ImplConfig:
types: TypeConfig types: TypeConfig
schedules: list[ScheduleConfig] schedules: list[ScheduleConfig]
heuristic: list[tuple[Optional[str], ScheduleConfig]] heuristic: list[tuple[str | None, ScheduleConfig]]
def generate_sch_sig(schedule_config: ScheduleConfig) -> str: def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
tile_shape = ( tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
) )
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + cluster_shape = (
f"x{schedule_config.cluster_shape_mnk[1]}" + f"{schedule_config.cluster_shape_mnk[0]}"
f"x{schedule_config.cluster_shape_mnk[2]}") + f"x{schedule_config.cluster_shape_mnk[1]}"
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ + f"x{schedule_config.cluster_shape_mnk[2]}"
.split("::")[-1] )
epilogue_schedule = EpilogueScheduleTag[ kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
schedule_config.epilogue_schedule].split("::")[-1] "::"
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ )[-1]
.split("::")[-1] epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
"::"
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + )[-1]
f"_{epilogue_schedule}_{tile_scheduler}") tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
return (
f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
+ f"_{epilogue_schedule}_{tile_scheduler}"
)
# mostly unique shorter sch_sig # mostly unique shorter sch_sig
...@@ -316,18 +321,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: ...@@ -316,18 +321,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
# unique type_name # unique type_name
def generate_type_signature(kernel_types: TypeConfig): def generate_type_signature(kernel_types: TypeConfig):
return str("".join([ return str(
VLLMDataTypeNames[getattr(kernel_types, field.name)] "".join(
for field in fields(TypeConfig) [
])) VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
)
def generate_type_option_name(kernel_types: TypeConfig): def generate_type_option_name(kernel_types: TypeConfig):
return ", ".join([ return ", ".join(
f"{field.name.replace('b_', 'with_')+'_type'}=" + [
VLLMDataTypeNames[getattr(kernel_types, field.name)] f"{field.name.replace('b_', 'with_') + '_type'}="
for field in fields(TypeConfig) + VLLMDataTypeNames[getattr(kernel_types, field.name)]
]) for field in fields(TypeConfig)
]
)
def is_power_of_two(n): def is_power_of_two(n):
...@@ -335,7 +346,6 @@ def is_power_of_two(n): ...@@ -335,7 +346,6 @@ def is_power_of_two(n):
def to_cute_constant(value: list[int]): def to_cute_constant(value: list[int]):
def _to_cute_constant(value: int): def _to_cute_constant(value: int):
if is_power_of_two(value): if is_power_of_two(value):
return f"_{value}" return f"_{value}"
...@@ -350,11 +360,11 @@ def to_cute_constant(value: list[int]): ...@@ -350,11 +360,11 @@ def to_cute_constant(value: list[int]):
def unique_schedules(impl_configs: list[ImplConfig]): def unique_schedules(impl_configs: list[ImplConfig]):
# Use dict over set for deterministic ordering # Use dict over set for deterministic ordering
return list({ return list(
sch: None {
for impl_config in impl_configs sch: None for impl_config in impl_configs for sch in impl_config.schedules
for sch in impl_config.schedules }.keys()
}.keys()) )
def unsigned_type_with_bitwidth(num_bits): def unsigned_type_with_bitwidth(num_bits):
...@@ -380,7 +390,7 @@ template_globals = { ...@@ -380,7 +390,7 @@ template_globals = {
"gen_type_sig": generate_type_signature, "gen_type_sig": generate_type_signature,
"unique_schedules": unique_schedules, "unique_schedules": unique_schedules,
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
"gen_type_option_name": generate_type_option_name "gen_type_option_name": generate_type_option_name,
} }
...@@ -398,23 +408,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE) ...@@ -398,23 +408,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
sources = [] sources = []
sources.append(( sources.append(
"machete_mm_dispatch", (
mm_dispatch_template.render(impl_configs=impl_configs), "machete_mm_dispatch",
)) mm_dispatch_template.render(impl_configs=impl_configs),
)
)
prepack_types = [] prepack_types = []
for impl_config in impl_configs: for impl_config in impl_configs:
convert_type = impl_config.types.a \ convert_type = (
if impl_config.types.b_group_scale == DataType.void \ impl_config.types.a
else impl_config.types.b_group_scale if impl_config.types.b_group_scale == DataType.void
else impl_config.types.b_group_scale
)
prepack_types.append( prepack_types.append(
PrepackTypeConfig( PrepackTypeConfig(
a=impl_config.types.a, a=impl_config.types.a,
b_num_bits=VLLMDataTypeSize[impl_config.types.b], b_num_bits=VLLMDataTypeSize[impl_config.types.b],
convert=convert_type, convert=convert_type,
accumulator=impl_config.types.accumulator, accumulator=impl_config.types.accumulator,
)) )
)
def prepacked_type_key(prepack_type: PrepackTypeConfig): def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now, we can just use the first accumulator type seen since # For now, we can just use the first accumulator type seen since
...@@ -430,10 +445,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): ...@@ -430,10 +445,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
unique_prepack_types.append(prepack_type) unique_prepack_types.append(prepack_type)
prepack_types_seen.add(key) prepack_types_seen.add(key)
sources.append(( sources.append(
"machete_prepack", (
prepack_dispatch_template.render(types=unique_prepack_types, ), "machete_prepack",
)) prepack_dispatch_template.render(
types=unique_prepack_types,
),
)
)
# Split up impls across files # Split up impls across files
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
...@@ -466,10 +485,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): ...@@ -466,10 +485,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
curr_impl_in_file += len(files_impls[-1][-1].schedules) curr_impl_in_file += len(files_impls[-1][-1].schedules)
for part, file_impls in enumerate(files_impls): for part, file_impls in enumerate(files_impls):
sources.append(( sources.append(
f"machete_mm_impl_part{part+1}", (
mm_impl_template.render(impl_configs=file_impls), f"machete_mm_impl_part{part + 1}",
)) mm_impl_template.render(impl_configs=file_impls),
)
)
return sources return sources
...@@ -514,8 +535,7 @@ def generate(): ...@@ -514,8 +535,7 @@ def generate():
# For now we use the same heuristic for all types # For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s # Heuristic is currently tuned for H100s
default_heuristic = [ default_heuristic = [
(cond, ScheduleConfig(*tile_config, (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
**sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items() for cond, tile_config in default_tile_heuristic_config.items()
] ]
...@@ -541,14 +561,18 @@ def generate(): ...@@ -541,14 +561,18 @@ def generate():
a_token_scale=DataType.void, a_token_scale=DataType.void,
out=a, out=a,
accumulator=DataType.f32, accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) )
for a in (DataType.f16, DataType.bf16)) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [ impl_configs += [
ImplConfig(x[0], x[1], x[2]) ImplConfig(x[0], x[1], x[2])
for x in zip(GPTQ_kernel_type_configs, for x in zip(
itertools.repeat(get_unique_schedules(default_heuristic)), GPTQ_kernel_type_configs,
itertools.repeat(default_heuristic)) itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
] ]
AWQ_kernel_type_configs = list( AWQ_kernel_type_configs = list(
...@@ -561,14 +585,18 @@ def generate(): ...@@ -561,14 +585,18 @@ def generate():
a_token_scale=DataType.void, a_token_scale=DataType.void,
out=a, out=a,
accumulator=DataType.f32, accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8) )
for a in (DataType.f16, DataType.bf16)) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [ impl_configs += [
ImplConfig(x[0], x[1], x[2]) ImplConfig(x[0], x[1], x[2])
for x in zip(AWQ_kernel_type_configs, for x in zip(
itertools.repeat(get_unique_schedules(default_heuristic)), AWQ_kernel_type_configs,
itertools.repeat(default_heuristic)) itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
] ]
# TODO: Support W4A8 when ready # TODO: Support W4A8 when ready
......
...@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else { } else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>, OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
...@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else { } else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>, OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
...@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else { } else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>, OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
...@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
// TMA epilogue isn't compatible with Swap A/B // TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>, OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment