Commit 96ae75ad authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev

parents f9f4a735 2339d59f
#include <climits>
#include <iostream>
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
\ No newline at end of file
#include "cutlass_extensions/common.hpp"
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
\ No newline at end of file
......@@ -2,20 +2,27 @@
#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
/**
* Panic wrapper for unwinding CUDA runtime errors
*/
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
......@@ -25,3 +32,4 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
return max_shared_mem_per_block_opt_in;
}
int32_t get_sm_version_num();
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
......
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
......@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
......
......@@ -123,6 +123,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
......@@ -147,7 +233,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
// If we have very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution should be calculating the exact right
// amount of shared memory and use that. The num_experts >= 256 is just a
// temporary solution to unblock Deepseek V3.
if (num_experts >= 256) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t mem_tokens_cnts =
((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);
auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
kernel<<<1, num_thread, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
int32_t shared_mem_normal = ((num_thread + 1) * num_experts + (num_experts + 1)) *
......@@ -185,6 +305,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.numel());
}
});
}
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
......
......@@ -303,6 +303,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
torch::Tensor& e, torch::Tensor const& a);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
......@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "common.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
......
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#include <torch/all.h>
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <iostream>
#include <sstream>
#include <vector>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
// M in (128, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
// For M > 128 and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
// For M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
// For M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
// For M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
// For M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
} // namespace
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
......
#pragma once
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on
/*
Epilogues defined in,
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
must contain a public type named EVTCompute of type Sm90EVT, as well as a
static prepare_args function that constructs an EVTCompute::Arguments struct.
*/
using namespace cute;
namespace vllm {
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule>
struct cutlass_3x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
ElementD, EpilogueSchedule>;
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
using ElementC = void;
using StrideC = StrideD;
using EVTCompute = typename Epilogue::EVTCompute;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>;
// clang-format off
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,
ElementAcc, TileShape, ClusterShape,
Stages,
KernelSchedule>::CollectiveOp;
// clang-format on
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args};
// Launch the CUTLASS GEMM kernel.
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
} // namespace vllm
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
// M in (128, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _128>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
// m in [1, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
\ No newline at end of file
#pragma once
#include "scaled_mm_c3x.cuh"
/**
* This file defines Gemm kernel configurations for SM90 (int8) based on the
* Gemm shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
// For M > 128 and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
// For M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
// For M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
// For M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
// For M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
\ No newline at end of file
......@@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
......@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false;
}
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
......
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
int m = a.size(0);
int k = a.size(1);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using ElementA = ElementA_;
using LayoutTagA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementA;
using LayoutTagB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = float;
using LayoutTagC = cutlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Core kernel configurations
using ElementAccumulator = ElementAcc_;
using TileShape = Shape<_128, _128, _128>;
using TileShapeRef = Shape<_128, _128, _64>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = typename std::conditional<
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>::type;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using ProblemShape = Shape<int, int, int, int>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
AlignmentC, ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
// The n (=1) dimension does not matter for the compressor
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
// Offline compressor kernel
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig,
cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
auto [M, N, K, L] = prob_shape;
StrideA stride_A;
stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
CompressorUtility compressor_utility(prob_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return true;
}
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
if (a.dtype() == torch::kBFloat16) {
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kInt8) {
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}
#endif
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
a_nzs.size(1) * 2 == a.size(1) &&
a_meta.size(1) * 2 * 4 == a.size(1));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
a_meta.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using namespace cute;
using namespace vllm;
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM512 =
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm1 =
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm2 =
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm3 =
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm4 =
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm5 =
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm6 =
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm7 =
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const n = bt_nzs.size(0);
uint32_t const m = a.size(0); // Batch size
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 4096 || n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 128) {
if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 256) {
if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else {
if (n == 6144 || n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
// Otherwise the default heuristic
if (mp2 <= 64) {
// n in [1, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// n in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// n in (128, 256]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
// n in (256, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::half_t>());
TORCH_CHECK(a.dtype() == torch::kFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
cutlass::bfloat16_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
} else {
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
out, a, bt_nzs, bt_meta, b_scales, a_scales);
}
}
#endif
This diff is collapsed.
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
#endif
return false;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
a.size(0) == c.size(0));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
c.stride(1) == 1); // Row-major
TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(bt_nzs.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
......@@ -511,6 +511,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops.def(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
ops.impl("cutlass_sparse_scaled_mm_supported",
&cutlass_sparse_scaled_mm_supported);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor
ops.def(
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
" Tensor a) -> bool");
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
......
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
myst-parser==3.0.1
sphinx-argparse==0.4.0
msgspec
cloudpickle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment