Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
#pragma once
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
* Gemm shape.
*/
namespace vllm {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
// M in (256, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M256 {
// M in (64, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in (16, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M16 {
// M in [1, 16]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM16 =
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// m in [1, 16]
return cutlass_gemm_caller<Cutlass3xGemmM16>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// m in (16, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// m in (64, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (256, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace vllm
\ No newline at end of file
......@@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
......@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
"12.8 or above.");
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
}
......@@ -31,6 +31,13 @@
namespace vllm {
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return (x + y - 1) / y * y;
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
......@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
// Each thread writes 4 uint32_t elements.
for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
col += blockDim.x * 4) {
SFout[row * sf_n_int + col] = 0x00;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
......@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx, colIdx, numCols, SFout);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
}
}
}
......
......@@ -159,7 +159,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
......
......@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// #include "quantization/fp8/common.cuh"
// #include "quantization/w8a8/fp8/common.cuh"
namespace vllm {
......
......@@ -189,7 +189,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
const uint32_t*, const half*,
half*, const int, const int,
const int, const int,
const int*);
const bool, const int*);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_4bit_kernel(
......@@ -197,12 +197,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
......@@ -260,10 +263,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
// Column result
float block_c[m_count][4] = {};
......@@ -276,10 +279,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
}
#pragma unroll
......@@ -333,12 +336,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
......@@ -413,10 +419,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
#pragma unroll
for (int m = 0; m < m_count; m++) {
......@@ -452,12 +458,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
......@@ -538,13 +547,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1);
size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1);
size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1);
size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1);
size_n, zeros[3] + zero_offset);
#pragma unroll
for (int m = 0; m < m_count; m++) {
......@@ -578,12 +587,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
......@@ -662,13 +674,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1);
zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1);
zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1);
zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1);
zeros[3] + zero_offset);
for (int m = 0; m < m_count; m++) {
block_c[m][0] =
......@@ -734,7 +746,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm,
half* c, int size_m, int size_n, int size_k,
int m_count, int groups, int bit) {
int m_count, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
......@@ -747,20 +760,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros,
b_gptq_scales, c, size_m, size_n,
size_k, groups, b_q_perm);
kernel<<<gridDim, blockDim, 0, stream>>>(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
groups, use_v2_format, b_q_perm);
}
__global__ void reconstruct_exllama_8bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -816,13 +832,13 @@ __global__ void reconstruct_exllama_8bit_kernel(
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1);
zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1);
zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1);
zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1);
zeros[3] + zero_offset);
// half* dqh = (half*)dq;
if (b_q_perm) {
......@@ -853,11 +869,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -892,10 +911,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
__syncthreads();
......@@ -908,10 +927,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++) {
......@@ -958,11 +977,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -1020,13 +1042,13 @@ __global__ void reconstruct_exllama_3bit_kernel(
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1);
size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1);
size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1);
size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1);
size_n, zeros[3] + zero_offset);
if (b_q_perm) {
for (int j = 0; j < 16; j++) {
......@@ -1056,11 +1078,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -1112,10 +1137,10 @@ __global__ void reconstruct_exllama_2bit_kernel(
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
b_ptr += size_n;
// half* dqh = (half*)dq;
......@@ -1147,7 +1172,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm,
half* out, int height, int width, int groups,
int bit) {
bool use_v2_format, int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
......@@ -1166,14 +1191,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
out);
use_v2_format, out);
}
__global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) {
int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 8;
int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
......@@ -1183,6 +1208,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
......@@ -1227,10 +1255,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
half2 zero = __halves2half2(
__hmul(scale_f,
__int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
1)),
__hmul(scale_f2,
__int2half_rn(
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)));
zero_offset)),
__hmul(
scale_f2,
__int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) -
zero_offset)));
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
......@@ -1272,7 +1301,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) {
int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
......@@ -1282,6 +1311,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
......@@ -1316,12 +1348,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f,
__int2half_rn(
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
__hmul(scale_f2,
__int2half_rn(
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)));
__hmul(scale_f, __int2half_rn(
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)),
__hmul(
scale_f2,
__int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)));
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
......@@ -1359,7 +1392,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx,
half* c, int size_m, int size_n, int size_k,
int bit) {
bool use_v2_format, int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
......@@ -1376,17 +1409,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n);
size_m, size_k / 32 * bit, size_n, use_v2_format);
}
template <class T, int bit>
__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height, const int width,
const int group,
half* __restrict__ out) {
__global__ void reconstruct_gptq_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group,
const bool use_v2_format, half* __restrict__ out) {
// Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
......@@ -1399,6 +1430,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
MatrixView_half w_scales_(w_scales, group, width);
T w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column);
......@@ -1406,7 +1440,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
for (int s = 0; s < 32; s += bit) {
int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
half w_item =
__hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
w_scale);
......@@ -1419,7 +1453,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group,
half* __restrict__ out) {
const bool use_v2_format, half* __restrict__ out) {
// Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32;
......@@ -1431,6 +1465,9 @@ __global__ void reconstruct_gptq_3bit_kernel(
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
......@@ -1440,7 +1477,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
for (int i = 0; i < 32; i += 1) {
int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
int w_item;
if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
......@@ -1460,7 +1497,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, half* out,
int height, int width, int groups, int bit) {
int height, int width, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
......@@ -1480,7 +1518,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
b_gptq_qzeros, b_g_idx, height,
width, groups, out);
width, groups, use_v2_format, out);
}
void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
......@@ -1488,7 +1526,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx,
half* c, half* temp_dq, int size_m, int size_n,
int size_k, int groups, bool use_exllama, int bit) {
int size_k, int groups, bool use_exllama,
bool use_v2_format, int bit) {
bool use_reconstruct;
if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
......@@ -1502,10 +1541,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit);
temp_dq, size_k, size_n, groups, use_v2_format, bit);
} else {
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit);
temp_dq, size_k, size_n, groups, use_v2_format, bit);
}
const half alpha = __float2half(1.0f);
......@@ -1521,18 +1560,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
if (max_chunks) {
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c, last_chunk, size_n, size_k,
BLOCK_M_SIZE_MAX, groups, bit);
BLOCK_M_SIZE_MAX, groups, use_v2_format, bit);
}
if (last_chunk_size) {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight,
b_gptq_qzeros, b_gptq_scales, b_g_idx,
c + last_chunk * size_n, last_chunk_size,
size_n, size_k, last_chunk_size, groups, bit);
gemm_half_q_half_cuda_part(
a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k,
last_chunk_size, groups, use_v2_format, bit);
}
} else {
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k, bit);
c, size_m, size_n, size_k, use_v2_format, bit);
}
}
......@@ -1819,7 +1858,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit) {
bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
......@@ -1837,7 +1876,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
c.size(1), // n
a.size(1), // k
b_gptq_qzeros.size(0), // group number
use_exllama, bit);
use_exllama, use_v2_format, bit);
return c;
}
......
......@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out;
}
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
......@@ -17,28 +17,32 @@ FILE_HEAD = """
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
"vllm::kFE2M1f"
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
(128, 64, 128)]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
......@@ -59,11 +63,12 @@ def generate_new_kernels():
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
......@@ -93,8 +98,7 @@ def generate_new_kernels():
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
group_blocks == 4:
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
......
......@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out;
}
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
......@@ -802,7 +802,7 @@ torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) {
});
if (numel % 256 != 0) {
out = out.index({torch::indexing::Slice(0, numel / had_size)});
out = out.narrow(0, 0, numel / had_size);
}
if (inplace && out.data_ptr() != x.data_ptr()) {
......
......@@ -9,23 +9,23 @@ from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Optional, Union
import jinja2
# yapf conflicts with isort for this block
# yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType, VLLMDataType,
VLLMDataTypeNames,
VLLMDataTypeSize, VLLMDataTypeTag,
VLLMDataTypeTorchDataTypeTag,
VLLMDataTypeVLLMScalarTypeTag,
VLLMKernelScheduleTag)
# yapf: enable
from vllm_cutlass_library_extension import (
DataType,
EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType,
VLLMDataType,
VLLMDataTypeNames,
VLLMDataTypeSize,
VLLMDataTypeTag,
VLLMDataTypeTorchDataTypeTag,
VLLMDataTypeVLLMScalarTypeTag,
VLLMKernelScheduleTag,
)
#
# Generator templating
......@@ -258,7 +258,7 @@ class ScheduleConfig:
@dataclass(frozen=True)
class TypeConfig:
a: DataType
b: Union[DataType, VLLMDataType]
b: DataType | VLLMDataType
b_group_scale: DataType
b_group_zeropoint: DataType
b_channel_scale: DataType
......@@ -279,25 +279,30 @@ class PrepackTypeConfig:
class ImplConfig:
types: TypeConfig
schedules: list[ScheduleConfig]
heuristic: list[tuple[Optional[str], ScheduleConfig]]
heuristic: list[tuple[str | None, ScheduleConfig]]
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
)
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
f"x{schedule_config.cluster_shape_mnk[1]}" +
f"x{schedule_config.cluster_shape_mnk[2]}")
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
.split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[
schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
.split("::")[-1]
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
f"_{epilogue_schedule}_{tile_scheduler}")
cluster_shape = (
f"{schedule_config.cluster_shape_mnk[0]}"
+ f"x{schedule_config.cluster_shape_mnk[1]}"
+ f"x{schedule_config.cluster_shape_mnk[2]}"
)
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
"::"
)[-1]
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
"::"
)[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
return (
f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
+ f"_{epilogue_schedule}_{tile_scheduler}"
)
# mostly unique shorter sch_sig
......@@ -316,18 +321,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
return str("".join([
VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]))
return str(
"".join(
[
VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
)
def generate_type_option_name(kernel_types: TypeConfig):
return ", ".join([
f"{field.name.replace('b_', 'with_')+'_type'}=" +
VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
])
return ", ".join(
[
f"{field.name.replace('b_', 'with_') + '_type'}="
+ VLLMDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
def is_power_of_two(n):
......@@ -335,7 +346,6 @@ def is_power_of_two(n):
def to_cute_constant(value: list[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
......@@ -350,11 +360,11 @@ def to_cute_constant(value: list[int]):
def unique_schedules(impl_configs: list[ImplConfig]):
# Use dict over set for deterministic ordering
return list({
sch: None
for impl_config in impl_configs
for sch in impl_config.schedules
}.keys())
return list(
{
sch: None for impl_config in impl_configs for sch in impl_config.schedules
}.keys()
)
def unsigned_type_with_bitwidth(num_bits):
......@@ -380,7 +390,7 @@ template_globals = {
"gen_type_sig": generate_type_signature,
"unique_schedules": unique_schedules,
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
"gen_type_option_name": generate_type_option_name
"gen_type_option_name": generate_type_option_name,
}
......@@ -398,23 +408,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
sources = []
sources.append((
"machete_mm_dispatch",
mm_dispatch_template.render(impl_configs=impl_configs),
))
sources.append(
(
"machete_mm_dispatch",
mm_dispatch_template.render(impl_configs=impl_configs),
)
)
prepack_types = []
for impl_config in impl_configs:
convert_type = impl_config.types.a \
if impl_config.types.b_group_scale == DataType.void \
else impl_config.types.b_group_scale
convert_type = (
impl_config.types.a
if impl_config.types.b_group_scale == DataType.void
else impl_config.types.b_group_scale
)
prepack_types.append(
PrepackTypeConfig(
a=impl_config.types.a,
b_num_bits=VLLMDataTypeSize[impl_config.types.b],
convert=convert_type,
accumulator=impl_config.types.accumulator,
))
)
)
def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now, we can just use the first accumulator type seen since
......@@ -430,10 +445,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
unique_prepack_types.append(prepack_type)
prepack_types_seen.add(key)
sources.append((
"machete_prepack",
prepack_dispatch_template.render(types=unique_prepack_types, ),
))
sources.append(
(
"machete_prepack",
prepack_dispatch_template.render(
types=unique_prepack_types,
),
)
)
# Split up impls across files
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
......@@ -466,10 +485,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
curr_impl_in_file += len(files_impls[-1][-1].schedules)
for part, file_impls in enumerate(files_impls):
sources.append((
f"machete_mm_impl_part{part+1}",
mm_impl_template.render(impl_configs=file_impls),
))
sources.append(
(
f"machete_mm_impl_part{part + 1}",
mm_impl_template.render(impl_configs=file_impls),
)
)
return sources
......@@ -514,8 +535,7 @@ def generate():
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
(cond, ScheduleConfig(*tile_config,
**sch_common_params)) # type: ignore
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items()
]
......@@ -541,14 +561,18 @@ def generate():
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16))
)
for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
for x in zip(
GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
]
AWQ_kernel_type_configs = list(
......@@ -561,14 +585,18 @@ def generate():
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16))
)
for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(AWQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
for x in zip(
AWQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
]
# TODO: Support W4A8 when ready
......
......@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
......@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
......@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
}
......@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
// TMA epilogue isn't compatible with Swap A/B
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
} // namespace vllm
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment