Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static inline __device__ int8_t float_to_int8_rn(float x) {
......@@ -103,134 +105,172 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
namespace vllm {
template <typename scalar_t, typename scale_type>
template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
}
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) / scale);
});
}
template <typename scalar_t, typename scale_type, typename azp_type>
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
const float scale = *scale_ptr;
const azp_t azp = *azp_ptr;
const float inv_s = 1.0f / scale;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[i] = quant_val;
}
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
template <typename scalar_t, typename scale_type>
template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int64_t const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float block_absmax_val;
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// calculate for absmax
float thread_max = 0.f;
vectorize_read_with_alignment<16>(
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
const float v = fabsf(static_cast<float>(src));
thread_max = fmaxf(thread_max, v);
});
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
__shared__ float absmax;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
absmax = block_max;
scale_out[blockIdx.x] = absmax / 127.f;
}
__syncthreads();
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
}
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
});
}
template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int64_t const token_idx = blockIdx.x;
// MinMax structure to hold min and max values in one go
struct MinMax {
float min, max;
// Must be performed using 64-bit math to avoid integer overflow.
out += token_idx * hidden_size;
input += token_idx * hidden_size;
// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
__host__ __device__ MinMax()
: min(std::numeric_limits<float>::max()),
max(std::numeric_limits<float>::lowest()) {}
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v);
max = fmaxf(max, v);
return *this;
}
// Reduce the max and min values across the block
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
__syncthreads(); // Make sure min doesn't mess with max shared memory
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
__shared__ scale_type scale_sh;
__shared__ azp_type azp_sh;
// Compute the scale and zero point and store them, only on the first thread
if (threadIdx.x == 0) {
float const scale_val = (max_val - min_val) / 255.0f;
// Use rounding to even (same as torch.round)
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
auto const azp_val = static_cast<azp_type>(azp_float);
// Store the scale and azp into shared and global
scale[token_idx] = scale_sh = scale_val;
azp[token_idx] = azp_sh = azp_val;
// merge two MinMax objects
__host__ __device__ MinMax& operator&=(const MinMax& other) {
min = fminf(min, other.min);
max = fmaxf(max, other.max);
return *this;
}
};
// Wait for the scale and azp to be computed
__syncthreads();
__host__ __device__ inline MinMax operator+(MinMax a, float v) {
return a += v;
}
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
return a &= b;
}
float const scale_val = scale_sh;
azp_type const azp_val = azp_sh;
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ output,
scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
const int tid = threadIdx.x;
const int stride = blockDim.x;
const int64_t token_idx = blockIdx.x;
// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[i] = quant_val;
// Must be performed using 64-bit math to avoid integer overflow.
const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size;
// 1. calculate min & max
MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) {
thread_mm += static_cast<float>(src);
});
using BlockReduce = cub::BlockReduce<MinMax, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
MinMax mm = BlockReduce(tmp).Reduce(
thread_mm,
[] __device__(MinMax a, const MinMax& b) {
a &= b;
return a;
},
blockDim.x);
__shared__ float scale_sh;
__shared__ azp_t azp_sh;
if (tid == 0) {
float s = (mm.max - mm.min) / 255.f;
float zp = nearbyintf(-128.f - mm.min / s); // round-to-even
scale_sh = s;
azp_sh = azp_t(zp);
scale_out[blockIdx.x] = s;
azp_out[blockIdx.x] = azp_sh;
}
__syncthreads();
const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh;
// 2. quantize
vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) {
const auto v = static_cast<float>(src) * inv_s;
dst = int32_to_int8(float_to_int32_rn(v) + azp);
});
}
} // namespace vllm
......@@ -247,7 +287,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
......@@ -278,7 +318,7 @@ void dynamic_scaled_int8_quant(
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
dim3 const block(std::min(hidden_size, 256));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
......
......@@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
static constexpr int AlignmentCD =
128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
......@@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 {
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm_sm120 {
using ElementAB = ElementAB_;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementAB>::value;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementAB>::value;
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementD_>::value;
using ElementD = ElementD_;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
// MMA type
using ElementAccumulator = float;
// Epilogue types
using ElementBias = cutlass::half_t;
using ElementCompute = float;
using ElementAux = ElementD;
using LayoutAux = LayoutD;
using ElementAmax = float;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
} // namespace vllm
......@@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
......
......@@ -15,11 +15,11 @@ using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
// M in (128, inf)
// M in (256, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_256, _128, _64>;
using TileShape = Shape<_256, _128, _128>;
using ClusterShape = Shape<_2, _2, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
......@@ -28,13 +28,13 @@ struct sm100_fp8_config_default {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M128 {
// M in (64, 128]
struct sm100_fp8_config_M256 {
// M in (64, 256]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
......@@ -43,12 +43,26 @@ struct sm100_fp8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in [1, 64]
// M in (16, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M16 {
// M in [1, 16]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
......@@ -68,25 +82,31 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM16 =
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
if (mp2 <= 16) {
// m in [1, 16]
return cutlass_gemm_caller<Cutlass3xGemmM16>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// m in (16, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
} else if (mp2 <= 256) {
// m in (64, 256]
return cutlass_gemm_caller<Cutlass3xGemmM256>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
// m in (256, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
......
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm
#pragma once
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines Gemm kernel configurations for SM120 (fp8) based on the
* Gemm shape.
*/
namespace vllm {
using c3x::cutlass_gemm_caller;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_default {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>; // Only work with Shape<_1, _1, _1>
using Cutlass3xGemm =
cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm120_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
} // namespace vllm
\ No newline at end of file
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <cassert>
using namespace cute;
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
__global__ void get_ggemm_starts(
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
ElementAccumulator* a_scale_base_as_int,
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
int expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
int m = problem_sizes[expert_id * 3];
int n = problem_sizes[expert_id * 3 + 1];
int k = problem_sizes[expert_id * 3 + 2];
int32_t expert_offset = expert_offsets[expert_id];
int a_stride = expert_offset * k;
int b_stride = expert_id * k * n;
int a_scale_stride = expert_offset * k / 128;
int b_scale_stride = expert_id * k * n / 128 / 128;
a_offsets[expert_id] = a_base_as_int + a_stride;
b_offsets[expert_id] = b_base_as_int + b_stride;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr =
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
*layout_sfb_ptr =
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<float**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<float*>(b_scales.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<int*>(problem_sizes.data_ptr())); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_ggemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Unsupported output tensor type");
}
}
template <typename OutType, typename ScheduleConfig, typename LayoutD>
void run_blockwise_scaled_group_mm(
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Types
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA,
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
AlignmentA, ElementB,
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = (int)expert_offsets.size(0);
Gemm gemm_op;
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementA**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(stride_a.data_ptr()),
static_cast<const ElementB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(stride_b.data_ptr()),
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
layout_sfa.data_ptr()),
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
layout_sfb.data_ptr())};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_ptrs.get_device();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(stride_c.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(stride_c.data_ptr())};
UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info};
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM");
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void blockwise_scaled_group_mm_dispatch_shape(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
struct MmaConfig {
using ElementA = cutlass::float_e4m3_t;
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using LayoutC = cutlass::layout::RowMajor;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
};
int num_experts = (int)expert_offsets.size(0);
auto a_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto b_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto out_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto a_scales_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto b_scales_ptrs = torch::empty(
{num_experts},
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto layout_sfa = torch::empty(
{num_experts, 5},
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
auto layout_sfb = torch::empty(
{num_experts, 5},
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
auto stride_a = torch::full(
{num_experts}, a.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto stride_b = torch::full(
{num_experts}, a.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
auto stride_c = torch::full(
{num_experts}, output.size(1),
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
torch::TensorOptions options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
typename MmaConfig::LayoutSFB,
typename MmaConfig::ScaleConfig>(
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
run_blockwise_scaled_group_mm<OutType, MmaConfig,
typename MmaConfig::LayoutC>(
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
expert_offsets);
}
void cutlass_blockwise_scaled_grouped_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
"a must be kFloat8_e4m3fn");
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
"b must be kFloat8_e4m3fn");
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
output.scalar_type() == torch::kHalf,
"output must be bfloat16 or half");
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
"scales_a must be float32");
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
"scales_b must be float32");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
"expert_offsets must be int32");
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
if (output.scalar_type() == torch::kBFloat16) {
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
} else if (output.scalar_type() == torch::kFloat16) {
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
} else {
TORCH_CHECK(false, "Unsupported output tensor type");
}
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_blockwise_scaled_grouped_mm",
&cutlass_blockwise_scaled_grouped_mm);
}
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm120 (Blackwell Geforce).
*/
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
}
#endif
......@@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
......@@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100) {
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
......@@ -241,7 +256,7 @@ void get_cutlass_moe_mm_data(
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
......@@ -252,7 +267,7 @@ void get_cutlass_moe_mm_data(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90");
version_num, ". Required capability: 90 or 100");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
......@@ -265,7 +280,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
......@@ -275,7 +291,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
version_num, ". Required capability: 90 or 100");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
......
......@@ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm(
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
#endif
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
......
......@@ -231,7 +231,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
......@@ -240,7 +240,7 @@ cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
......@@ -248,49 +248,182 @@ cvt_fp16_to_fp4(
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts.
int rowIdx_in_expert = 0;
int expert_idx = 0;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
globalIdx += gridDim.x * blockDim.x) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts using different strategies based on expert
// count
int rowIdx_in_expert = 0;
int expert_idx = 0;
if constexpr (SMALL_NUM_EXPERTS) {
for (int i = 0; i < n_experts; i++) {
if (rowIdx >= input_offset_by_experts[i] &&
rowIdx < input_offset_by_experts[i + 1]) {
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
if (rowIdx >= current_offset && rowIdx < next_offset) {
rowIdx_in_expert = rowIdx - current_offset;
expert_idx = i;
break;
}
}
} else {
// Load input offsets into registers first, then do the computation.
// Local array size set to 17 because of register limit.
uint32_t local_offsets[17];
for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
*reinterpret_cast<int4*>(local_offsets) =
__ldca(reinterpret_cast<const int4*>(
&input_offset_by_experts[chunk_start]));
*reinterpret_cast<int4*>(local_offsets + 4) =
__ldca(reinterpret_cast<const int4*>(
&input_offset_by_experts[chunk_start + 4]));
*reinterpret_cast<int4*>(local_offsets + 8) =
__ldca(reinterpret_cast<const int4*>(
&input_offset_by_experts[chunk_start + 8]));
*reinterpret_cast<int4*>(local_offsets + 12) =
__ldca(reinterpret_cast<const int4*>(
&input_offset_by_experts[chunk_start + 12]));
local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);
// Check against the 16 loaded offsets
#pragma unroll
for (int i = 0; i < 16; i++) {
if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
rowIdx_in_expert = rowIdx - local_offsets[i];
expert_idx = chunk_start + i;
break;
}
}
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
#endif
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(1024, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
extern __shared__ uint32_t shared_input_offsets[];
// Load input offsets into shared memory.
// If n_experts is larger than 4, use vectorized int4 to save instructions.
// If n_experts is smaller than 4, read directly.
if constexpr (SMALL_NUM_EXPERTS) {
for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
shared_input_offsets[i] = input_offset_by_experts[i];
}
} else {
for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
*reinterpret_cast<int4*>(&shared_input_offsets[i]) =
*reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
}
if (threadIdx.x == 0) {
shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
globalIdx += gridDim.x * blockDim.x) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find expert using binary search for better performance with large m_topk
int rowIdx_in_expert = 0;
int expert_idx = 0;
// Binary search through experts using shared memory
int left = 0, right = n_experts - 1;
while (left <= right) {
int mid = (left + right) / 2;
// Get offsets: shared_input_offsets[i] corresponds to
// input_offset_by_experts[i]
uint32_t mid_offset = shared_input_offsets[mid];
uint32_t next_offset = shared_input_offsets[mid + 1];
if (rowIdx >= mid_offset && rowIdx < next_offset) {
rowIdx_in_expert = rowIdx - mid_offset;
expert_idx = mid;
break;
} else if (rowIdx < mid_offset) {
right = mid - 1;
} else {
left = mid + 1;
}
}
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert =
SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
#endif
}
......@@ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input,
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
int const workSizePerRow = k / ELTS_PER_THREAD;
int const totalWorkSize = m_topk * workSizePerRow;
dim3 block(std::min(workSizePerRow, 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts);
dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
multiProcessorCount * numBlocksPerSM));
while (grid.x <= multiProcessorCount && block.x > 64) {
grid.x *= 2;
block.x = (block.x + 1) / 2;
}
int const blockRepeat =
(totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
if (n_experts >= 4) {
cvt_fp16_to_fp4<T, false, false>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
}
} else {
if (n_experts >= 16) {
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
}
}
}
/*Quantization entry for fp4 experts quantization*/
......@@ -383,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
......@@ -401,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}
\ No newline at end of file
}
......@@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
// We don't support e8m0 scales at this moment.
......
......@@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
B_sf.sizes()[1], ")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
if (out_dtype == at::ScalarType::Half) {
......
......@@ -448,8 +448,6 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
[[maybe_unused]] __half2_raw h2r =
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union {
__half2_raw h2r;
uint32_t ui32;
......
......@@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int64_t type, int64_t row) {
int col = X.sizes()[1];
int vecs = X.sizes()[0];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({1, row}, options);
at::Tensor Y = torch::empty({vecs, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
(void*)quant_X.data_ptr(), col, 1, stream);
quantize_row_q8_1_cuda<scalar_t>(
(scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
switch (type) {
case 2:
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 3:
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 6:
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 7:
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 8:
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 10:
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 11:
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 12:
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 13:
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 14:
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 16:
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 17:
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 18:
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 19:
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 20:
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 21:
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 22:
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
(scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
}
});
......
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows, const int nvecs) {
const auto row = blockIdx.x*blockDim.y + threadIdx.y;
const auto vec = blockIdx.y;
if (row >= nrows) {
if (row >= nrows || vec >= nvecs) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
const int nrows_y = (ncols + 512 - 1) / 512 * 512;
// partial sum for each thread
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
......@@ -19,7 +22,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iby = vec*(nrows_y/QK8_1) + i * (qk/QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
......@@ -33,177 +36,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
}
if (threadIdx.x == 0) {
dst[row] = tmp;
dst[vec*nrows + row] = tmp;
}
}
template<typename scalar_t>
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template<typename scalar_t>
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
......@@ -210,8 +210,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
......@@ -348,8 +346,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
......@@ -469,8 +465,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
......@@ -597,8 +591,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
auto offset_m = blockIdx.y * m_count;
auto offset_k = blockIdx.z * BLOCK_KN_SIZE;
[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
......
......@@ -1113,8 +1113,6 @@ __global__ void Marlin(
if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
......
......@@ -38,7 +38,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
......@@ -1003,7 +1002,7 @@ struct MacheteCollectiveMma {
static constexpr int A_CPY_VEC =
decltype(max_common_vector(tCsA, tCrA_load)){};
static constexpr int COVERSION_WIDTH =
static constexpr int CONVERSION_WIDTH =
std::min(A_CPY_VEC, int(size<0>(tCrA_mma)));
auto load_A_to_registers = [&](int read_stage) {
......@@ -1026,8 +1025,8 @@ struct MacheteCollectiveMma {
// PIPELINED MAIN LOOP
//
auto convert_A = [&, a_vec = Int<COVERSION_WIDTH>{}](int k_block,
int read_stage) {
auto convert_A = [&, a_vec = Int<CONVERSION_WIDTH>{}](int k_block,
int read_stage) {
load_extra_info_to_registers(partitioned_extra_info,
copy_partitions_extra_info, k_block,
read_stage);
......
#pragma once
#include "vectorization.cuh"
namespace vllm {
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
ScaOp scalar_op;
__device__ __forceinline__ void operator()(
vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
scalar_op(dst.val[i], src.val[i]);
}
}
};
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
typename ScaOp>
__device__ inline void vectorize_with_alignment(
const InT* in, OutT* out, int len, int tid, int stride,
VecOp&& vec_op, // vec_n_t<InT,16> -> vec_n_t<OutT,16>
ScaOp&& scalar_op) { // InT -> OutT
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
"VEC_SIZE must be a positive power-of-two");
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
// fast path when the whole region is already aligned
// Note: currently the output is guaranteed to be same as the input, so we
// don't check it here, comments here just for future reference.
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
if (can_vec) {
int num_vec = len / VEC_SIZE;
using vin_t = vec_n_t<InT, VEC_SIZE>;
using vout_t = vec_n_t<OutT, VEC_SIZE>;
auto* v_in = reinterpret_cast<const vin_t*>(in);
auto* v_out = reinterpret_cast<vout_t*>(out);
for (int i = tid; i < num_vec; i += stride) {
vout_t tmp;
vec_op(tmp, v_in[i]);
v_out[i] = tmp;
}
return;
}
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
prefix_elems /= sizeof(InT);
prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16
// 1. prefill the when it is unsafe to vectorize
for (int i = tid; i < prefix_elems; i += stride) {
scalar_op(out[i], in[i]);
}
in += prefix_elems;
out += prefix_elems;
len -= prefix_elems;
int num_vec = len / VEC_SIZE;
using vin_t = vec_n_t<InT, VEC_SIZE>;
using vout_t = vec_n_t<OutT, VEC_SIZE>;
auto* v_in = reinterpret_cast<const vin_t*>(in);
auto* v_out = reinterpret_cast<vout_t*>(out);
// 2. vectorize the main part
for (int i = tid; i < num_vec; i += stride) {
vout_t tmp;
vec_op(tmp, v_in[i]);
v_out[i] = tmp;
}
// 3. handle the tail
int tail_start = num_vec * VEC_SIZE;
for (int i = tid + tail_start; i < len; i += stride) {
scalar_op(out[i], in[i]);
}
}
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
OutT* out, int len,
int tid, int stride,
ScaOp&& scalar_op) {
using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;
vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
std::forward<ScaOp>(scalar_op));
}
template <int VEC_SIZE, typename InT, typename ScaOp>
struct DefaultReadVecOp {
ScaOp scalar_op;
__device__ __forceinline__ void operator()(
const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
scalar_op(src.val[i]);
}
}
};
// read-only version: iterate over the input with alignment guarantees
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
int tid, int stride,
VecOp&& vec_op,
ScaOp&& scalar_op) {
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
"VEC_SIZE must be a positive power-of-two");
constexpr int WIDTH = VEC_SIZE * sizeof(InT);
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
// fast path when the whole region is already aligned
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
if (can_vec) {
int num_vec = len / VEC_SIZE;
using vin_t = vec_n_t<InT, VEC_SIZE>;
auto* v_in = reinterpret_cast<const vin_t*>(in);
for (int i = tid; i < num_vec; i += stride) {
vec_op(v_in[i]);
}
return;
}
int misalignment_offset = addr & (WIDTH - 1);
int alignment_bytes = WIDTH - misalignment_offset;
int prefix_elems = alignment_bytes & (WIDTH - 1);
prefix_elems /= sizeof(InT);
prefix_elems = min(prefix_elems, len);
// 1. handle the possibly unaligned prefix with scalar access.
for (int i = tid; i < prefix_elems; i += stride) {
scalar_op(in[i]);
}
in += prefix_elems;
len -= prefix_elems;
int num_vec = len / VEC_SIZE;
using vin_t = vec_n_t<InT, VEC_SIZE>;
auto* v_in = reinterpret_cast<const vin_t*>(in);
// 2. vectorized traversal of the main aligned region.
for (int i = tid; i < num_vec; i += stride) {
vec_op(v_in[i]);
}
// 3. handle remaining tail elements.
int tail_start = num_vec * VEC_SIZE;
for (int i = tid + tail_start; i < len; i += stride) {
scalar_op(in[i]);
}
}
// overload that requires only a scalar_op
template <int VEC_SIZE, typename InT, typename ScaOp>
__device__ __forceinline__ void vectorize_read_with_alignment(
const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
std::forward<ScaOp>(scalar_op));
}
} // namespace vllm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment