Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
...@@ -16,6 +16,30 @@ namespace vllm::c3x { ...@@ -16,6 +16,30 @@ namespace vllm::c3x {
using namespace cute; using namespace cute;
template <typename T>
struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;
template <typename... Args>
static ArgumentType prepare_args(Args... args) {
return {};
}
};
/* /*
* This class provides the common load descriptors for the * This class provides the common load descriptors for the
* ScaledEpilogue[...] classes * ScaledEpilogue[...] classes
...@@ -174,6 +198,49 @@ struct ScaledEpilogueBias ...@@ -174,6 +198,49 @@ struct ScaledEpilogueBias
} }
}; };
/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args};
return ArgumentType{a_args, evt0_args, bias_args};
}
};
/* /*
* This epilogue directly supports per-tensor azp in int32 form. * This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
...@@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken ...@@ -314,4 +381,4 @@ struct ScaledEpilogueBiasAzpToken
} }
}; };
}; // namespace vllm::c3x }; // namespace vllm::c3x
\ No newline at end of file
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <THC/THCAtomics.cuh> #include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h" #include "../cuda_compat.h"
#include "../dispatch_utils.h" #include "../dispatch_utils.h"
...@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel( ...@@ -198,26 +198,27 @@ __global__ void moe_align_block_size_global_mem_kernel(
} }
// taken from // taken from
// https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t> template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel( __global__ void sgl_moe_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* cumsum) { int32_t block_size, size_t numel, int32_t* cumsum) {
__shared__ int32_t shared_counts[32][8]; __shared__ int32_t shared_counts[32][8];
__shared__ int32_t local_offsets[256];
const int warp_id = threadIdx.x / 32; const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int experts_per_warp = 8; const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp; const int my_expert_start = warp_id * experts_per_warp;
// Initialize shared_counts for this warp's experts
for (int i = 0; i < experts_per_warp; ++i) { for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < num_experts) { if (my_expert_start + i < num_experts) {
shared_counts[warp_id][i] = 0; shared_counts[warp_id][i] = 0;
} }
} }
__syncthreads();
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
...@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel( ...@@ -230,6 +231,7 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads(); __syncthreads();
// Single thread computes cumulative sum and total tokens
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
cumsum[0] = 0; cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { for (int i = 1; i <= num_experts; ++i) {
...@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel( ...@@ -246,19 +248,28 @@ __global__ void sgl_moe_align_block_size_kernel(
__syncthreads(); __syncthreads();
// Assign expert IDs to blocks
if (threadIdx.x < num_experts) { if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) { i += block_size) {
expert_ids[i / block_size] = threadIdx.x; expert_ids[i / block_size] = threadIdx.x;
} }
local_offsets[threadIdx.x] = cumsum[threadIdx.x];
} }
}
__syncthreads(); // taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { template <typename scalar_t>
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids,
int32_t* cumsum_buffer,
size_t numel) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i]; int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
} }
} }
...@@ -633,23 +644,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -633,23 +644,34 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(num_experts == 256,
"sgl_moe_align_block_size kernel only supports deepseek v3.");
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `cumsum` tensors
// tensors
auto options_int = auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
// torch::Tensor token_cnts_buffer =
// torch::empty({(num_experts + 1) * num_experts}, options_int);
torch::Tensor cumsum_buffer = torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int); torch::zeros({num_experts + 1}, options_int);
auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>; auto align_kernel =
kernel<<<1, 1024, 0, stream>>>( vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
align_kernel<<<1, 1024, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
const int block_threads = 256;
const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
}); });
} }
......
...@@ -317,8 +317,11 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, ...@@ -317,8 +317,11 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed, std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
torch::Tensor& e, torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale,
torch::Tensor const& input_scale);
#endif #endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
...@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); // num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int seq_dim_idx = positions_ndim - 1;
int num_kv_heads = key.size(-1) / head_size; int64_t query_stride = query.stride(seq_dim_idx);
int64_t query_stride = query.stride(-2); int64_t key_stride = key.stride(seq_dim_idx);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -165,19 +201,58 @@ and process in batched manner. ...@@ -165,19 +201,58 @@ and process in batched manner.
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim, bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
) { ) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size; TORCH_CHECK(
int num_kv_heads = key.size(-1) / head_size; positions.size(0) == num_tokens || positions.numel() == num_tokens,
int64_t query_stride = query.stride(-2); "positions must have the same num_tokens or batch_size as "
int64_t key_stride = key.stride(-2); "cos_sin_cache_offsets");
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have concistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
......
...@@ -334,7 +334,7 @@ __global__ void __launch_bounds__(64) ...@@ -334,7 +334,7 @@ __global__ void __launch_bounds__(64)
} }
// TODO: Shang: Hoist loop invariance. // TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int ax1_0_1 = 0; ax1_0_1 < (N / 32); ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) { for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 +
((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
......
# CUTLASS Epilogues # CUTLASS Epilogues
## Introduction ## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
Currently, we only support symmetric quantization for weights, Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations. and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations). Both can be quantized per-tensor or per-channel (weights) / per-token (activations).
There are 4 epilogues: There are 4 epilogues:
1. ScaledEpilogue: symmetric quantization for activations, no bias.
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. 1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. 1. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. 1. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
1. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.
We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias. Instead, if no bias is passed, the epilogue will use 0 as the bias.
...@@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following ...@@ -26,12 +28,15 @@ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
```math ```math
A = s_a (\widehat A - J_a z_a) A = s_a (\widehat A - J_a z_a)
``` ```
```math ```math
B = s_b \widehat B B = s_b \widehat B
``` ```
```math ```math
D = A B + C D = A B + C
``` ```
```math ```math
D = s_a s_b \widehat D + C D = s_a s_b \widehat D + C
``` ```
...@@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows: ...@@ -48,9 +53,11 @@ Expanding further, we can calculate $` \widehat D `$ as follows:
```math ```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
``` ```
```math ```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
``` ```
```math ```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B \widehat D = \widehat A \widehat B - z_a J_a \widehat B
``` ```
...@@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of ...@@ -61,16 +68,19 @@ Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of
## Epilogues ## Epilogues
### ScaledEpilogue ### `ScaledEpilogue`
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is: The output of the GEMM is:
```math ```math
\widehat D = \widehat A \widehat B \widehat D = \widehat A \widehat B
``` ```
```math ```math
D = s_a s_b \widehat D D = s_a s_b \widehat D
``` ```
```math ```math
D = s_a s_b \widehat A \widehat B D = s_a s_b \widehat A \widehat B
``` ```
...@@ -79,44 +89,51 @@ Epilogue parameters: ...@@ -79,44 +89,51 @@ Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
### ScaledEpilogueBias ### `ScaledEpilogueBias`
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is: The output of the GEMM is:
```math ```math
\widehat D = \widehat A \widehat B \widehat D = \widehat A \widehat B
``` ```
```math ```math
D = s_a s_b \widehat D + C D = s_a s_b \widehat D + C
``` ```
```math ```math
D = s_a s_b \widehat A \widehat B + C D = s_a s_b \widehat A \widehat B + C
``` ```
Epilogue parameters: Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, is always per-channel (row-vector). - `bias` is the bias, is always per-channel (row-vector).
### ScaledEpilogueAzp ### `ScaledEpilogueAzp`
This epilogue computes the asymmetric per-tensor quantization for activations with bias. This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is: The output of the GEMM is:
```math ```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B \widehat D = \widehat A \widehat B - z_a J_a \widehat B
``` ```
```math ```math
D = s_a s_b \widehat D + C D = s_a s_b \widehat D + C
``` ```
```math ```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
``` ```
Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$.
That is precomputed and stored in `azp_with_adj` as a row-vector. That is precomputed and stored in `azp_with_adj` as a row-vector.
Epilogue parameters: Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- Generally this will be per-tensor as the zero-points are per-tensor. - Generally this will be per-tensor as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
...@@ -125,13 +142,15 @@ Epilogue parameters: ...@@ -125,13 +142,15 @@ Epilogue parameters:
To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
### ScaledEpilogueAzpPerToken ### `ScaledEpilogueAzpPerToken`
This epilogue computes the asymmetric per-token quantization for activations with bias. This epilogue computes the asymmetric per-token quantization for activations with bias.
The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
Epilogue parameters: Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- Generally this will be per-token as the zero-points are per-token. - Generally this will be per-token as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
...@@ -142,6 +161,7 @@ Epilogue parameters: ...@@ -142,6 +161,7 @@ Epilogue parameters:
To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
```
```math
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
``` ```
...@@ -53,12 +53,17 @@ struct cutlass_3x_gemm { ...@@ -53,12 +53,17 @@ struct cutlass_3x_gemm {
using EVTCompute = typename Epilogue::EVTCompute; using EVTCompute = typename Epilogue::EVTCompute;
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
using CollectiveEpilogue = using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD,
EpilogueSchedule, EVTCompute>::CollectiveOp; AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize = static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage); sizeof(typename CollectiveEpilogue::SharedStorage);
...@@ -69,8 +74,8 @@ struct cutlass_3x_gemm { ...@@ -69,8 +74,8 @@ struct cutlass_3x_gemm {
using CollectiveMainloop = using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder< typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
ElementAB, cutlass::layout::RowMajor, 16, ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, 16, ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
ElementAcc, TileShape, ClusterShape, ElementAcc, TileShape, ClusterShape,
Stages, Stages,
KernelSchedule>::CollectiveOp; KernelSchedule>::CollectiveOp;
......
...@@ -103,14 +103,19 @@ struct cutlass_2x_gemm { ...@@ -103,14 +103,19 @@ struct cutlass_2x_gemm {
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>; using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
// clang-format off // clang-format off
using RowMajor = typename cutlass::layout::RowMajor; using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor; using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType = using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor< ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, ElementAB, RowMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentAB,
float, cutlass::layout::RowMajor, 4, float, cutlass::layout::RowMajor, AlignmentCD,
ElementAcc, float, cutlass::arch::OpClassTensorOp, ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch, Arch,
TileShape, WarpShape, InstructionShape, TileShape, WarpShape, InstructionShape,
......
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "cuda_utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
"f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
return val;
#else
return 0;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
return 0;
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
int numCols,
SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
uint32_t* out, uint32_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
#endif
}
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
if (useUE8M0) {
cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
} else {
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput));
}
}
// Instantiate the function.
template void invokeFP4Quantization(int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount,
cudaStream_t stream);
void scaled_fp4_quant_sm100a(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
auto stream = at::cuda::getStreamFromPool(false, input.get_device());
if (stream == nullptr) {
std::cerr << "Warning: Null CUDA stream" << std::endl;
}
// We don't support e8m0 scales at this moment.
bool useUE8M0 = false;
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
useUE8M0, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
useUE8M0, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type()
<< " for the input datatype which is invalid";
throw std::runtime_error(
"Unsupported input data type for quantize_to_fp4.");
}
}
}
#pragma once #pragma once
#if defined(__HIPCC__) && \ #if defined(__HIPCC__) && defined(__gfx942__)
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__ #define __HIP__MI300__
#endif #endif
......
...@@ -6,25 +6,25 @@ Machete is a spiritual successor to the Marlin kernel but optimized for Hopper a ...@@ -6,25 +6,25 @@ Machete is a spiritual successor to the Marlin kernel but optimized for Hopper a
Machete effectively performs Machete effectively performs
``` ```python
scale_type = w_s.dtype scale_type = w_s.dtype
compute_type = a.dtype compute_type = a.dtype
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
``` ```
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
`w_z` is the quantization zeropoints. `w_z` is the quantization zeropoints.
> **_NOTE:_** `w_z` is added after the scales so we can > **_NOTE:_** `w_z` is added after the scales so we can
use FMA operations, but this means they must have the scales pre-applied if the use FMA operations, but this means they must have the scales pre-applied if the
supplied zeropoints assume that they will be subtracted before the scales are supplied zeropoints assume that they will be subtracted before the scales are
applied. applied.
## API ## API
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like: The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
``` ```python
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
... ...
...@@ -40,6 +40,6 @@ output = ops.machete_gemm( ...@@ -40,6 +40,6 @@ output = ops.machete_gemm(
## Code Generation ## Code Generation
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`. Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate. New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
\ No newline at end of file
...@@ -24,8 +24,7 @@ ...@@ -24,8 +24,7 @@
#include "../attention/dtype_fp8.cuh" #include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__ #define __HIP__MI300_MI250__
#endif #endif
...@@ -1122,4 +1121,4 @@ void paged_attention( ...@@ -1122,4 +1121,4 @@ void paged_attention(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
int m = a.size(0);
int k = a.size(1);
// Sparse kernel setup; this kernel is not used for matmul,
// but just for setting up the compressor utility
// A matrix configuration
using ElementA = ElementA_;
using LayoutTagA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
// B matrix configuration
using ElementB = ElementA;
using LayoutTagB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
// C/D matrix configuration
using ElementC = float;
using LayoutTagC = cutlass::layout::ColumnMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
// Core kernel configurations
using ElementAccumulator = ElementAcc_;
using TileShape = Shape<_128, _128, _128>;
using TileShapeRef = Shape<_128, _128, _64>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = typename std::conditional<
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
cutlass::gemm::KernelTmaWarpSpecialized>::type;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using ProblemShape = Shape<int, int, int, int>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
AlignmentC, ElementC, LayoutTagC, AlignmentC,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
// The n (=1) dimension does not matter for the compressor
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
// Offline compressor kernel
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig,
cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
auto [M, N, K, L] = prob_shape;
StrideA stride_A;
stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
CompressorUtility compressor_utility(prob_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return true;
}
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
if (a.dtype() == torch::kBFloat16) {
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
a);
} else if (a.dtype() == torch::kInt8) {
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
}
return false;
}
#endif
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
using GemmKernel = typename Gemm::KernelType;
using ElementA = typename Gemm::ElementAB;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
int m = a.size(0);
int k = a.size(1);
using ProblemShape = typename GemmKernel::ProblemShape;
ProblemShape prob_shape{m, 1, k, 1};
int64_t lda = a.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
StrideA a_stride{lda, Int<1>{}, 0};
using CompressorUtility = typename Gemm::CompressorUtility;
CompressorUtility compressor_utility(prob_shape, a_stride);
// Allocate buffers for the metadata E and the compressed matrix A
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int MC = compressor_utility.get_tensorA_m_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto const a_meta_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto const a_nzs_options =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
using Compressor = typename Gemm::Compressor;
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return {a_meta, a_nzs};
}
#endif
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a);
#endif
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
a_nzs.size(1) * 2 == a.size(1) &&
a_meta.size(1) * 2 * 4 == a.size(1));
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
a_meta.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
...@@ -9,17 +9,30 @@ ...@@ -9,17 +9,30 @@
using namespace cute; using namespace cute;
using namespace vllm; using namespace vllm;
struct GemmCallerTraits {
using return_type = void;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_gemm_caller<GemmConfig>(std::forward<Args>(args)...);
}
};
struct GemmCompressorTraits {
using return_type = CompressorResult;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_compress<GemmConfig>(std::forward<Args>(args)...);
}
};
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename DispatchFunc, typename... Args>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
torch::Tensor const& bt_nzs, uint32_t m, uint32_t n, Args&&... args) {
torch::Tensor const& bt_meta, static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
...@@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, ...@@ -49,122 +62,87 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
using Cutlass3xGemm8 = using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const n = bt_nzs.size(0);
uint32_t const m = a.size(0); // Batch size
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) { if (mp2 <= 64) {
if (n == 28672) { if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm2>( return DispatchFunc::template invoke<Cutlass3xGemm2>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 4096 || n == 6144) { } else if (n == 4096 || n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm1>( return DispatchFunc::template invoke<Cutlass3xGemm1>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} else if (mp2 <= 128) { } else if (mp2 <= 128) {
if (n == 4096) { if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm3>( return DispatchFunc::template invoke<Cutlass3xGemm3>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 28672) { } else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm5>( return DispatchFunc::template invoke<Cutlass3xGemm5>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 6144) { } else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm4>( return DispatchFunc::template invoke<Cutlass3xGemm4>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} else if (mp2 <= 256) { } else if (mp2 <= 256) {
if (n == 4096) { if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm6>( return DispatchFunc::template invoke<Cutlass3xGemm6>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 28672) { } else if (n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>( return DispatchFunc::template invoke<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 6144) { } else if (n == 6144) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>( return DispatchFunc::template invoke<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} else { } else {
if (n == 6144 || n == 28672) { if (n == 6144 || n == 28672) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm8>( return DispatchFunc::template invoke<Cutlass3xGemm8>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (n == 4096) { } else if (n == 4096) {
return cutlass_sparse_gemm_caller<Cutlass3xGemm7>( return DispatchFunc::template invoke<Cutlass3xGemm7>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} }
// Otherwise the default heuristic // Otherwise the default heuristic
if (mp2 <= 64) { if (mp2 <= 64) {
// n in [1, 64] // n in [1, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>( return DispatchFunc::template invoke<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (mp2 <= 128) { } else if (mp2 <= 128) {
// n in (64, 128] // n in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>( return DispatchFunc::template invoke<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (mp2 <= 256) { } else if (mp2 <= 256) {
// n in (128, 256] // n in (128, 256]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>( return DispatchFunc::template invoke<Cutlass3xGemmM256>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else { } else {
// n in (256, inf) // n in (256, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>( return DispatchFunc::template invoke<Cutlass3xGemmM512>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} }
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename DispatchFunc, typename... Args>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a, typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
torch::Tensor const& bt_nzs, uint32_t m, uint32_t n, Args&&... args) {
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::half_t>());
TORCH_CHECK(a.dtype() == torch::kFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::bfloat16_t>());
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
// m in (128, inf) return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>( std::forward<Args>(args)...);
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
} }
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename DispatchFunc, typename... Args>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
torch::Tensor const& bt_nzs, uint32_t m, uint32_t n, Args&&... args) {
torch::Tensor const& bt_meta, static_assert(std::is_same_v<InType, int8_t>);
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm; typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
...@@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, ...@@ -179,37 +157,35 @@ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
typename sm90_int8_config_M32_NSmall<InType, OutType, typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm; Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192; bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 = uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2 std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) { if (mp2 <= 32) {
// m in [1, 32] // m in [1, 32]
if (is_small_n) { if (is_small_n) {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>( return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else { } else {
return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>( return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} else if (mp2 <= 64) { } else if (mp2 <= 64) {
// m in (32, 64] // m in (32, 64]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>( return DispatchFunc::template invoke<Cutlass3xGemmM64>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else if (mp2 <= 128) { } else if (mp2 <= 128) {
// m in (64, 128] // m in (64, 128]
return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>( return DispatchFunc::template invoke<Cutlass3xGemmM128>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} else { } else {
// m in (128, inf) // m in (128, inf)
return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>( return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...); std::forward<Args>(args)...);
} }
} }
// Dispatch to GEMM implementations based on element types
template <template <typename, typename, typename> typename Epilogue, template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
...@@ -217,19 +193,24 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, ...@@ -217,19 +193,24 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
torch::Tensor const& bt_nzs, torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta, torch::Tensor const& bt_meta,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
uint32_t const m = out.size(0);
uint32_t const n = out.size(1);
// TODO: add dispatch functions to all of these
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8); TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) { if (a.dtype() == torch::kInt8) {
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8); TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t, return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>( Epilogue, GemmCallerTraits>(
out, a, bt_nzs, bt_meta, m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>( return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
out, a, bt_nzs, bt_meta, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else if (a.dtype() == torch::kFloat8_e4m3fn) { } else if (a.dtype() == torch::kFloat8_e4m3fn) {
...@@ -237,47 +218,34 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out, ...@@ -237,47 +218,34 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
if (out.dtype() == torch::kBFloat16) { if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue,
out, a, bt_nzs, bt_meta, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm90_fp8_dispatch<
cutlass::half_t, Epilogue>( cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
out, a, bt_nzs, bt_meta, m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} }
} else if (a.dtype() == torch::kFloat16) { } else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16); TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
TORCH_CHECK(out.dtype() == torch::kFloat16);
if (out.dtype() == torch::kBFloat16) { return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, Epilogue, GemmCallerTraits>(
cutlass::bfloat16_t, Epilogue>( m, n, out, a, bt_nzs, bt_meta,
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else { // a.dtype() == torch::kBFloat16 } else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16); TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16); TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
TORCH_CHECK(out.dtype() == torch::kBFloat16);
if (out.dtype() == torch::kBFloat16) { return cutlass_gemm_sm90_16bit_dispatch<
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
cutlass::bfloat16_t, Epilogue>( m, n, out, a, bt_nzs, bt_meta,
out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(epilogue_args)...);
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
cutlass::half_t, Epilogue>(
out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} }
} }
...@@ -287,17 +255,53 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a, ...@@ -287,17 +255,53 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype()); "CUTLASS scaled_mm bias dtype must match output dtype ",
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>( out.dtype());
out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias); return cutlass_scaled_sparse_mm_sm90_epilogue<
c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
a_scales, *bias);
} else { } else {
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>( return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
out, a, bt_nzs, bt_meta, b_scales, a_scales); out, a, bt_nzs, bt_meta, b_scales, a_scales);
} }
} }
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
// These m and n variables are fordispatching to different GEMM algorithms.
uint32_t const m = 1; // Set M to 1 for compression
uint32_t const n = a.size(1);
// Note: For correctess, the compressed format must be invariant in:
// - M, the flattened number of tokens
// - Whether output dtype is fp16 or bf16
// - CUTLASS epilogues
if (a.dtype() == torch::kInt8) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_gemm_sm90_16bit_dispatch<
cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else {
TORCH_CHECK(a.dtype() == torch::kBFloat16,
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
"and bf16 datatypes");
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
}
}
#endif #endif
#pragma once
// clang-format will break include orders // clang-format will break include orders
// clang-format off // clang-format off
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
...@@ -12,6 +14,9 @@ ...@@ -12,6 +14,9 @@
#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "core/math.hpp" #include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
...@@ -22,7 +27,7 @@ ...@@ -22,7 +27,7 @@
using namespace cute; using namespace cute;
/* /*
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API, This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
for NVIDIA GPUs with sm90a (Hopper) or later. for NVIDIA GPUs with sm90a (Hopper) or later.
*/ */
...@@ -45,17 +50,20 @@ struct enable_sm90_or_later : Kernel { ...@@ -45,17 +50,20 @@ struct enable_sm90_or_later : Kernel {
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
/*
* cutlass_sparse_3x_gemm defines a 2:4 sparse GEMM kernel via CUTLASS
* for SM90 Hopper systems.
*/
template <typename ElementAB_, typename ElementD_, template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_, template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule, typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule, typename AccType, typename EpilogueSchedule>
typename TileSchedule = cutlass::gemm::PersistentScheduler,
GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
struct cutlass_sparse_3x_gemm { struct cutlass_sparse_3x_gemm {
static const GemmUniversalMode Mode = Mode_;
using ElementAB = ElementAB_; using ElementAB = ElementAB_;
using ElementD = ElementD_; using ElementD = ElementD_;
using ElementAcc = AccType; using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using EpilogueDescriptor = using EpilogueDescriptor =
cutlass::epilogue::collective::detail::EpilogueDescriptor< cutlass::epilogue::collective::detail::EpilogueDescriptor<
...@@ -66,30 +74,22 @@ struct cutlass_sparse_3x_gemm { ...@@ -66,30 +74,22 @@ struct cutlass_sparse_3x_gemm {
using ElementC = void; using ElementC = void;
using LayoutC = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using LayoutC_Transpose = using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type; typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using EVTCompute = typename Epilogue::EVTCompute; using EVTCompute = typename Epilogue::EVTCompute;
static constexpr int AlignmentA = // These are the minimum alignments needed for the kernels to compile
128 / cutlass::sizeof_bits<ElementAB>::value; static constexpr int AlignmentAB =
static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementAB>::value; 128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = static constexpr int AlignmentCD = 4;
128 / cutlass::sizeof_bits<ElementD>::value;
using CollectiveEpilogue = using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD, ElementAcc, float, ElementC, LayoutC_Transpose, AlignmentCD, ElementD,
ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule, LayoutC_Transpose, AlignmentCD, EpilogueSchedule,
EVTCompute>::CollectiveOp; EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize = static constexpr size_t CEStorageSize =
...@@ -101,8 +101,8 @@ struct cutlass_sparse_3x_gemm { ...@@ -101,8 +101,8 @@ struct cutlass_sparse_3x_gemm {
using CollectiveMainloop = using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder< typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
ElementAB, cutlass::layout::RowMajor, AlignmentA, ElementAB, cutlass::layout::RowMajor, AlignmentAB,
ElementAB, cutlass::layout::ColumnMajor, AlignmentB, ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
ElementAcc, TileShape, ClusterShape, ElementAcc, TileShape, ClusterShape,
Stages, Stages,
KernelSchedule>::CollectiveOp; KernelSchedule>::CollectiveOp;
...@@ -110,11 +110,100 @@ struct cutlass_sparse_3x_gemm { ...@@ -110,11 +110,100 @@ struct cutlass_sparse_3x_gemm {
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
TileSchedule>>; cutlass::gemm::PersistentScheduler>>;
struct GemmKernel : public KernelType {}; struct GemmKernel : public KernelType {};
// Sparse compressor definitions
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
using LayoutTagA = cutlass::layout::RowMajor;
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
SparseConfig>;
using CompressorKernel =
cutlass::transform::kernel::StructuredSparseCompressor<
typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
SparseConfig, cutlass::arch::Sm90>;
using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
}; };
/*
* This class defines kernel to compress a 2:4 sparse matrix.
* The particular format is defined by the Gemm template parameter,
* which is a cutlass_sparse_3x_gemm.
*/
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
using GemmKernel = typename Gemm::KernelType;
using ElementA = typename Gemm::ElementAB;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
int m = a.size(0);
int k = a.size(1);
using ProblemShape = typename GemmKernel::ProblemShape;
ProblemShape prob_shape{m, 1, k, 1};
int64_t lda = a.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
StrideA a_stride{lda, Int<1>{}, 0};
using CompressorUtility = typename Gemm::CompressorUtility;
CompressorUtility compressor_utility(prob_shape, a_stride);
// Allocate buffers for the metadata E and the compressed matrix A
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int MC = compressor_utility.get_tensorA_m_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto const a_meta_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto const a_nzs_options =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
using Compressor = typename Gemm::Compressor;
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return {a_meta, a_nzs};
}
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a, void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs, torch::Tensor const& bt_nzs,
...@@ -126,27 +215,25 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a, ...@@ -126,27 +215,25 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Interface stride expected from the argument a (will get transposed) // Interface stride expected from the argument a (will get transposed)
// We compute C^T = B^T * A^T, but we assume B is transposed before // We compute C^T = B^T * A^T, but we assume B is transposed before
// compression and hence the bt_* naming // compression and hence the bt_* naming
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
using LayoutD = cutlass::layout::RowMajor;
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>; // M, N, K after transposition
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>; int32_t m = out.size(1);
int32_t n = out.size(0);
int32_t k = a.size(1);
int64_t lda = a.stride(0);
int64_t ldc = out.stride(0);
auto layout_A = make_cute_layout<StrideA>(a, "A"); using StrideA = Stride<int64_t, Int<1>, int64_t>;
auto layout_D = make_cute_layout<StrideD>(out, "D"); using StrideC = Stride<Int<1>, int64_t, int64_t>;
// Transpose A and D StrideA a_stride{lda, Int<1>{}, Int<0>{}};
// A doesn't need to be transposed since cutlass expects a NxK matrix StrideC c_stride{Int<1>{}, ldc, Int<0>{}};
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{ typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
static_cast<int>(size<1>(layout_A)), 1};
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE; using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig; using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
...@@ -158,13 +245,13 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a, ...@@ -158,13 +245,13 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr()); auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr()); auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{ typename GemmKernel::MainloopArguments mainloop_args{
b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout}; b_ptr, b_layout, a_ptr, a_stride, e_ptr, e_layout};
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args( Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...), std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, stride_Dt, c_ptr, stride_Dt}; c_ptr, c_stride, c_ptr, c_stride};
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape, mainloop_args, epilogue_args}; prob_shape, mainloop_args, epilogue_args};
...@@ -185,6 +272,10 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a, ...@@ -185,6 +272,10 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }
//////////////////////////////////////////////////
// Gemm Configs are defined below
//////////////////////////////////////////////////
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_config_default {}; struct sm90_config_default {};
...@@ -192,28 +283,25 @@ struct sm90_config_default {}; ...@@ -192,28 +283,25 @@ struct sm90_config_default {};
template <typename OutType, template <typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> { struct sm90_config_default<half_t, OutType, Epilogue> {
// M in (128, inf) using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>; using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename OutType, template <typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> { struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
// M in (128, inf) using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>; using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape, cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule, ClusterShape, KernelSchedule, EpilogueSchedule>;
float>;
}; };
//////////////////////// Cherry-Picking Kernels //////////////////////// //////////////////////// Cherry-Picking Kernels ////////////////////////
...@@ -227,7 +315,7 @@ struct sm90_fp8_config_1 { ...@@ -227,7 +315,7 @@ struct sm90_fp8_config_1 {
using ClusterShape = Shape<_8, _1, _1>; using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -242,7 +330,7 @@ struct sm90_fp8_config_2 { ...@@ -242,7 +330,7 @@ struct sm90_fp8_config_2 {
using ClusterShape = Shape<_8, _1, _1>; using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -255,7 +343,7 @@ struct sm90_fp8_config_3 { ...@@ -255,7 +343,7 @@ struct sm90_fp8_config_3 {
using ClusterShape = Shape<_1, _2, _1>; using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -269,7 +357,7 @@ struct sm90_fp8_config_4 { ...@@ -269,7 +357,7 @@ struct sm90_fp8_config_4 {
using ClusterShape = Shape<_8, _1, _1>; using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -283,7 +371,7 @@ struct sm90_fp8_config_5 { ...@@ -283,7 +371,7 @@ struct sm90_fp8_config_5 {
using ClusterShape = Shape<_8, _1, _1>; using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -296,7 +384,7 @@ struct sm90_fp8_config_6 { ...@@ -296,7 +384,7 @@ struct sm90_fp8_config_6 {
using ClusterShape = Shape<_1, _2, _1>; using ClusterShape = Shape<_1, _2, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -311,7 +399,7 @@ struct sm90_fp8_config_7 { ...@@ -311,7 +399,7 @@ struct sm90_fp8_config_7 {
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -326,7 +414,7 @@ struct sm90_fp8_config_8 { ...@@ -326,7 +414,7 @@ struct sm90_fp8_config_8 {
using ClusterShape = Shape<_8, _1, _1>; using ClusterShape = Shape<_8, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float>; KernelSchedule, EpilogueSchedule>;
}; };
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
...@@ -341,7 +429,7 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> { ...@@ -341,7 +429,7 @@ struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue, cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
TileShape, ClusterShape, KernelSchedule, TileShape, ClusterShape, KernelSchedule,
EpilogueSchedule, float>; EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -355,12 +443,9 @@ struct sm90_fp8_config_M64 { ...@@ -355,12 +443,9 @@ struct sm90_fp8_config_M64 {
using TileShape = Shape<_64, _64, _256>; using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -374,12 +459,9 @@ struct sm90_fp8_config_M128 { ...@@ -374,12 +459,9 @@ struct sm90_fp8_config_M128 {
using TileShape = Shape<_64, _128, _256>; using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -394,12 +476,9 @@ struct sm90_fp8_config_M256 { ...@@ -394,12 +476,9 @@ struct sm90_fp8_config_M256 {
using TileShape = Shape<_128, _128, _256>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -414,12 +493,9 @@ struct sm90_fp8_config_M512 { ...@@ -414,12 +493,9 @@ struct sm90_fp8_config_M512 {
using TileShape = Shape<_128, _128, _256>; using TileShape = Shape<_128, _128, _256>;
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using TileSchedule = cutlass::gemm::PersistentScheduler;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, float, KernelSchedule, EpilogueSchedule>;
TileSchedule>;
}; };
template <typename OutType, template <typename OutType,
...@@ -433,7 +509,7 @@ struct sm90_config_default<int8_t, OutType, Epilogue> { ...@@ -433,7 +509,7 @@ struct sm90_config_default<int8_t, OutType, Epilogue> {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -448,7 +524,7 @@ struct sm90_int8_config_M128 { ...@@ -448,7 +524,7 @@ struct sm90_int8_config_M128 {
using ClusterShape = Shape<_2, _1, _1>; using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -462,7 +538,7 @@ struct sm90_int8_config_M64 { ...@@ -462,7 +538,7 @@ struct sm90_int8_config_M64 {
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -476,7 +552,7 @@ struct sm90_int8_config_M32_NBig { ...@@ -476,7 +552,7 @@ struct sm90_int8_config_M32_NBig {
using ClusterShape = Shape<_1, _4, _1>; using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
...@@ -490,7 +566,7 @@ struct sm90_int8_config_M32_NSmall { ...@@ -490,7 +566,7 @@ struct sm90_int8_config_M32_NSmall {
using ClusterShape = Shape<_1, _8, _1>; using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, int32_t>; KernelSchedule, EpilogueSchedule>;
}; };
} // namespace } // namespace
\ No newline at end of file
...@@ -23,6 +23,9 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -23,6 +23,9 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a);
#endif #endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
...@@ -68,3 +71,30 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a, ...@@ -68,3 +71,30 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
"CUDA device capability: ", "CUDA device capability: ",
version_num); version_num);
} }
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
std::vector<torch::Tensor> result_tensors;
auto [a_meta, a_nzs] = cutlass_sparse_compress_sm90(a);
result_tensors.push_back(std::move(a_nzs));
result_tensors.push_back(std::move(a_meta));
return result_tensors;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_sparse_compress for a compute capability less than "
"CUDA device capability: ",
version_num);
}
...@@ -538,10 +538,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -538,10 +538,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm); ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor // CUTLASS sparse matrix compressor
ops.def( ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
"cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta," ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
" Tensor a) -> bool");
ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);
// Mamba selective scan kernel // Mamba selective scan kernel
ops.def( ops.def(
...@@ -577,6 +575,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -577,6 +575,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool silu_activation," "bool silu_activation,"
"int pad_slot_id) -> ()"); "int pad_slot_id) -> ()");
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment