Commit 082094b7 authored by Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async)
// Issues a weak asynchronous 16-byte store to `dst_ptr` in the shared::cluster window and reports the
// transferred bytes on `mbar` (mbarrier::complete_tx::bytes).
//   dst_ptr : destination; converted with cute::cast_smem_ptr_to_uint, so it must be a shared-memory pointer
//   data    : exactly 16 bytes (enforced by the static_assert below)
//   mbar    : transaction barrier that receives the complete_tx signal
// NOTE(review): the asm statement has no "memory" clobber; the payload is passed in registers.
template<typename T>
CUTE_DEVICE
static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) {
static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async.");
// View the payload as two 64-bit halves so it can be passed as a .v2.s64 register pair.
long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
:
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
);
}
// XOR mask used by get_peer_addr: a single set bit, bit 24 (1 << 24 == 16777216).
static constexpr int PEER_ADDR_MASK = 1 << 24;
// Given an address in the current CTA, return the corresponding address in the peer CTA.
// The mapping is performed by toggling bit 24 of the pointer value, so applying it twice
// yields the original address again.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    const int64_t own_addr = (int64_t)(p);
    const int64_t peer_addr = own_addr ^ PEER_ADDR_MASK;
    return (T*)(peer_addr);
}
// Given an address in the current CTA, return the corresponding address in the peer CTA
// (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0).
// Implemented by clearing bit 24 of the pointer value; every other bit is preserved.
template<typename T>
CUTE_DEVICE
T* get_cta0_addr(const T* p) {
    // All ones except bit 24 (0xFFFFFFFFFEFFFFFF), i.e. the sign-extended form of 0xFEFFFFFF.
    constexpr int64_t kClearBit24 = ~static_cast<int64_t>(0x01000000);
    const int64_t own_addr = (int64_t)(p);
    return (T*)(own_addr & kClearBit24);
}
// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)
// Adds `store_bytes` bytes of f32 data from shared memory (`src_ptr`, shared::cta) into global memory (`dst_ptr`).
// The instruction uses the .bulk_group completion mechanism, so it is asynchronous.
// NOTE(review): presumably the caller must commit/wait on the bulk group before reusing the source buffer — confirm at call sites.
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
// The source operand is a 32-bit shared-memory address; the destination is a 64-bit generic/global pointer.
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
:
: "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// Emits `barrier.cluster.arrive.release`; the "memory" clobber keeps the compiler from moving memory
// accesses across the instruction. Does not wait — pair with a cluster barrier wait.
CUTE_DEVICE
void barrier_cluster_arrive_release() {
asm volatile("barrier.cluster.arrive.release;" : : : "memory");
}
// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// Emits `barrier.cluster.arrive.relaxed`. Unlike the .release variant in this file, the asm statement
// carries no "memory" clobber, so it places no compiler-level ordering constraint on surrounding memory accesses.
CUTE_DEVICE
void barrier_cluster_arrive_relaxed() {
asm volatile("barrier.cluster.arrive.relaxed;" : : :);
}
// Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
// Emits `barrier.cluster.wait.acquire`; the "memory" clobber keeps the compiler from moving memory
// accesses across the wait.
CUTE_DEVICE
void barrier_cluster_wait_acquire() {
asm volatile("barrier.cluster.wait.acquire;" : : : "memory");
}
// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
// Arrives on `mbar`, addressed in the shared::cta state space, with relaxed semantics at cluster scope.
// The state token returned by the instruction is discarded (the `_` sink operand).
// The asm statement has no "memory" clobber.
CUTE_DEVICE
void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile(
"{\n\t"
"mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t"
"}"
:
: "r"(smem_addr));
}
// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
// Performs a predicated, relaxed, gpu-scope vector reduction: adds the four floats of `data` to the
// four consecutive floats at `global_addr` (red....v4.f32), using `cache_policy` as the L2 cache hint.
//   pred : the reduction is issued only when pred == 1 exactly (setp.eq.u32 against 1); any other
//          value, including other non-zero values, skips it. Defaults to true (1).
// No value is returned (red, not atom) and the asm statement has no "memory" clobber.
CUTE_DEVICE
void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %6, 1;\n\t"
"@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
"}"
:
: "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
"l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
);
}
// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
// Asynchronously copies `load_bytes` bytes from `src_ptr` (shared::cta) to `dst_ptr` (shared::cluster)
// and reports the transferred bytes on `mbar` (mbarrier::complete_tx::bytes).
// All three addresses are converted with cute::cast_smem_ptr_to_uint, so they must be shared-memory pointers.
// NOTE(review): presumably dst_ptr/mbar are peer-CTA addresses obtained elsewhere (e.g. get_peer_addr) — confirm at call sites.
CUTE_DEVICE
void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {
uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile(
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
:
: "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr)
);
}
}
#pragma once
#include <exception>
#include <string>
#include <sstream>
#include <vector>
#include <cuda_runtime_api.h>
#include <cuda.h>
#include <cutlass/cuda_host_adapter.hpp>
#include "kerutils/common/common.h"
namespace kerutils {
// Exception type thrown by the KU_* checking macros.
// The message has the form "<name> error (<file>:<line>): " followed by every extra
// constructor argument streamed, in order, through operator<<.
class KUException final : public std::exception {
public:
    template<typename... Args>
    explicit KUException(const char *name, const char* file, const int line, Args&&... args) {
        std::ostringstream buffer;
        buffer << name;
        buffer << " error (" << file << ":" << line << "): ";
        // Fold over operator<< so any streamable argument type is accepted.
        (buffer << ... << args);
        message = buffer.str();
    }

    // Returns the formatted message; the pointer stays valid for the lifetime of this object.
    const char *what() const noexcept override {
        return message.c_str();
    }

private:
    std::string message = {};
};
// Throw a kerutils::KUException tagged with `name` and the current file/line.
// The variadic arguments are streamed into the message; because they follow a comma in the
// expansion, at least one argument must be supplied.
#define THROW_KU_EXCEPTION(name, ...) \
throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__)
// Evaluate a CUDA runtime call once; if it does not return cudaSuccess, print the error string
// (with file/line) to stderr and throw a KUException named "CUDA".
// Wrapped in do { } while(0) so it behaves as a single statement.
#define KU_CUDA_CHECK(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
THROW_KU_EXCEPTION("CUDA", "CUDA error: ", cudaGetErrorString(status_)); \
} \
} while(0)
// Evaluate a CUTLASS call once; if it does not return cutlass::Status::kSuccess, print the
// numeric status (with file/line) to stderr and throw a KUException named "CUTLASS".
#define KU_CUTLASS_CHECK(call) \
do { \
cutlass::Status status_ = call; \
if (status_ != cutlass::Status::kSuccess) { \
fprintf(stderr, "CUTLASS error (%s:%d): %d\n", __FILE__, __LINE__, static_cast<int>(status_)); \
THROW_KU_EXCEPTION("CUTLASS", "CUTLASS error: ", static_cast<int>(status_)); \
} \
} while(0)
// This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not.
// On failure it prints the stringified condition and file/line to stderr, then throws a
// KUException named "Assertion".
// Optional extra arguments are a printf-style format string plus its arguments. The format must
// be a string literal because it is concatenated with the leading ", " literal. The extra text is
// written to stderr only; it is not part of the exception message.
#define KU_ASSERT(cond, ...) \
do { \
if (not (cond)) { \
fprintf(stderr, "Assertion `%s` failed (%s:%d): ", #cond, __FILE__, __LINE__); \
if constexpr (sizeof(#__VA_ARGS__) > 1) { \
fprintf(stderr, ", " __VA_ARGS__); \
} \
fprintf(stderr, "\n"); \
THROW_KU_EXCEPTION("Assertion", "Assertion `", #cond, "` failed."); \
} \
} while(0)
// Check the result of cudaGetLastError() via KU_CUDA_CHECK; intended to follow a kernel launch.
// Note that cudaGetLastError() also clears the recorded error.
#define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError())
// Divide `a` by `b`, rounding the quotient up: computes (a + b - 1) / b.
// The formula yields the mathematical ceiling for non-negative `a` and positive `b`.
template<typename T>
inline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) {
    const auto biased = a + b - 1;
    return biased / b;
}
// Round `a` up to a multiple of `b`: computes (a + b - 1) / b * b.
// The formula yields the smallest multiple of `b` that is >= `a` for non-negative `a` and positive `b`.
template<typename T>
inline __host__ __device__ constexpr T ceil(const T &a, const T &b) {
    const auto biased = a + b - 1;
    return biased / b * b;
}
// A wrapper around cuTensorMapEncodeTiled that builds a CUtensorMap (TMA descriptor).
//
//   size             : extent of every dimension; size.size() is the tensor rank (must be >= 1)
//   strides          : strides IN BYTES; must contain rank-1 entries
//   box_size         : box extent per dimension; must contain rank entries
//   global_ptr       : base address of the tensor
//   element_strides_ : per-dimension element strides; when empty, 1 is used for every dimension
//   remaining args   : forwarded unchanged to cuTensorMapEncodeTiled
//
// Returns the encoded tensor map. On any inconsistency or driver failure a KUException is thrown
// (via KU_ASSERT) after all arguments have been dumped to stderr.
static inline CUtensorMap make_tensor_map(
    const std::vector<uint64_t> &size,
    const std::vector<uint64_t> &strides, // PAY ATTENTION: In BYTES
    const std::vector<uint32_t> &box_size,
    void* global_ptr,
    CUtensorMapDataType data_type,
    CUtensorMapSwizzle swizzle_mode,
    CUtensorMapL2promotion l2_promotion,
    CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
    CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    const std::vector<uint32_t> &element_strides_ = {}
) {
    int dim = size.size();
    KU_ASSERT(dim >= 1);

    // Default to a unit element stride in every dimension when the caller supplies none.
    std::vector<uint32_t> element_strides;
    if (element_strides_.empty()) {
        element_strides.assign(static_cast<size_t>(dim), 1u);
    } else {
        element_strides = element_strides_;
    }
    KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim);

    CUtensorMap result;
    CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
        &result,
        data_type,
        dim,
        global_ptr,
        size.data(),
        strides.data(),
        box_size.data(),
        element_strides.data(),
        interleave_mode,
        swizzle_mode,
        l2_promotion,
        oob_fill
    );

    if (ret_code != CUresult::CUDA_SUCCESS) {
        // Emit every diagnostic on stderr so the dump stays together and in order
        // even when stdout is buffered or redirected.
        auto print_vector = [](const char* label, const auto &values) {
            fprintf(stderr, "%s: ", label);
            for (auto elem : values) {
                fprintf(stderr, "%llu ", (unsigned long long)elem);
            }
            fprintf(stderr, "\n");
        };
        fprintf(stderr, "Failed to create tensormap (CUresult = %d)\n", (int)ret_code);
        fprintf(stderr, "Dim: %d\n", dim);
        print_vector("size", size);
        print_vector("strides", strides);
        print_vector("box_size", box_size);
        print_vector("element_strides", element_strides);
        fprintf(stderr, "global ptr: 0x%llx\n", (unsigned long long)(uintptr_t)global_ptr);
        fprintf(stderr, "data_type: %d\n", (int)data_type);
        fprintf(stderr, "swizzle_mode: %d\n", (int)swizzle_mode);
        fprintf(stderr, "l2_promotion: %d\n", (int)l2_promotion);
        fprintf(stderr, "interleave_mode: %d\n", (int)interleave_mode);
        fprintf(stderr, "oob_fill: %d\n", (int)oob_fill);
        KU_ASSERT(false);
    }
    return result;
}
// Convert strides expressed in elements into strides expressed in bytes.
// Each stride is first widened to uint64_t and then multiplied by `elem_size`, so the product is
// computed in 64 bits regardless of T.
//   strides_in_elems : strides counted in elements
//   elem_size        : size of one element in bytes
// Returns a vector with one byte-stride per input stride, in the same order.
template<typename T>
static inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) {
    std::vector<uint64_t> res;
    // The output length is known up front; reserve once instead of growing repeatedly.
    res.reserve(strides_in_elems.size());
    for (auto stride : strides_in_elems) {
        res.push_back(((uint64_t)stride) * elem_size);
    }
    return res;
}
}
\ No newline at end of file
#pragma once
#include "host/host.h"
#include "device/device.cuh"
#pragma once
#include <functional>
#include <torch/python.h>
#include "kerutils/common/common.h"
namespace kerutils {
// Evaluate `check_fn` on a tensor or on an optional tensor.
//   - T == at::Tensor: returns check_fn(tensor_or_opt).
//   - otherwise (an optional-like type): returns check_fn(value) when a value is present, and
//     true when it is empty (an absent tensor satisfies every condition).
template<typename T>
static inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function<bool(const at::Tensor&)>& check_fn) {
    if constexpr (!std::is_same<T, at::Tensor>::value) {
        if (!tensor_or_opt.has_value()) {
            return true;
        }
        return check_fn(tensor_or_opt.value());
    } else {
        return check_fn(tensor_or_opt);
    }
}
// Return the data pointer of `tensor` cast to PtrT*.
// A tensor for which has_storage() is false yields nullptr instead.
template<typename PtrT>
static inline PtrT* get_tensor_ptr(const at::Tensor& tensor) {
    if (!tensor.has_storage()) {
        return nullptr;
    }
    return (PtrT*)tensor.data_ptr();
}
// Return the data pointer (as PtrT*) of a tensor or of an optional tensor.
//   - T == at::Tensor: forwards to get_tensor_ptr<PtrT>.
//   - otherwise (an optional-like type): forwards the contained tensor to get_tensor_ptr<PtrT>
//     when a value is present, and returns nullptr when it is empty.
template<typename PtrT, typename T>
static inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) {
    if constexpr (!std::is_same<T, at::Tensor>::value) {
        if (!tensor_or_opt.has_value()) {
            return nullptr;
        }
        return get_tensor_ptr<PtrT>(*tensor_or_opt);
    } else {
        return get_tensor_ptr<PtrT>(tensor_or_opt);
    }
}
}
// The KU_CHECK_* macros below accept either an at::Tensor or an optional tensor. An empty
// optional always passes (see _check_optional_tensor). On failure TORCH_CHECK raises with a
// message built from the stringified argument expressions.

// Check whether the given tensor (or optional<tensor>) is on CUDA
#define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor " must be on CUDA")
// Check whether the given tensor (or optional<tensor>) has the given number of dimensions
#define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor " must have " #ndim " dimensions")
// Check whether the given tensor (or optional<tensor>) has the given shape
#define KU_CHECK_SHAPE(tensor, ...) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor " must have shape (" #__VA_ARGS__ ")")
// Check whether the given tensor (or optional<tensor>) is contiguous
#define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor " must be contiguous")
// Check whether the last dimension of the given tensor (or optional<tensor>) is contiguous,
// i.e. its last dimension has size 1 or stride 1
#define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor " must have contiguous last dimension")
// Check whether the given tensor (or optional<tensor>) has the specified dtype.
// An empty optional always passes. The lambda captures by reference (like KU_CHECK_NDIM and
// KU_CHECK_SHAPE) so that `target_dtype` may be a local variable and not only a global constant.
#define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor " must have dtype " #target_dtype)
......@@ -2,7 +2,21 @@
#include "cutlass/bfloat16.h"
struct DecodingParams {
// Model variant selector; stored in SparseAttnDecodeParams::model_type.
// NOTE(review): what distinguishes the variants is not visible in this file — see the kernels that read model_type.
enum class ModelType {
V32,
MODEL1
};
// One tile-scheduler record for split-KV decoding. The parameter structs in this file hold an
// array of these (tile_scheduler_metadata_ptr, documented as [num_sm_parts, ]).
// The struct holds eight ints and is aligned to 4*8 bytes; `_pad` fills the last slot.
// Do not reorder the fields. NOTE(review): the layout is presumably shared with device kernels — confirm before changing.
struct __align__(4*8) DecodingSchedMeta {
int begin_req_idx, end_req_idx; // Both inclusive
int begin_block_idx, end_block_idx; // Inclusive, exclusive
int begin_split_idx;
// NOTE(review): presumably flags telling whether the first/last request of this range is split across several records — confirm.
int is_first_req_splitted, is_last_req_splitted;
int _pad[1];
};
// Size of one DecodingSchedMeta record in bytes.
static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta);
struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
using index_t = int64_t;
int b; // batch size
......@@ -14,13 +28,11 @@ struct DecodingParams {
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
int topk;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
int *__restrict__ indices_ptr;
float *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
......@@ -31,38 +43,106 @@ struct DecodingParams {
index_t q_head_stride;
index_t k_head_stride;
index_t o_head_stride;
index_t indices_batch_stride;
index_t indices_row_stride;
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
int total_num_splits;
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
float *__restrict__ softmax_lseaccum_ptr;
float *__restrict__ oaccum_ptr;
cudaStream_t stream;
};
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _]
struct SparseAttnDecodeParams {
int b, s_q;
int h_q, h_kv;
int d_qk, d_v;
float sm_scale, sm_scale_div_log2;
int num_blocks, page_block_size, topk;
ModelType model_type;
struct GetDecodingMetadataParams {
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr;
int batch_size;
cutlass::bfloat16_t* __restrict__ q; // [b, s_q, h_q, d_qk]
cutlass::bfloat16_t* __restrict__ kv; // [num_blocks, page_block_size, d_qk]
int* __restrict__ indices; // [b, s_q, topk]
int* __restrict__ topk_length; // [b], may be nullptr
float* __restrict__ attn_sink; // [h_q], may be nullptr
float* __restrict__ lse; // [b, s_q, h_q]
cutlass::bfloat16_t* __restrict__ out; // [b, s_q, h_q, d_v]
int extra_num_blocks, extra_page_block_size, extra_topk;
cutlass::bfloat16_t* __restrict__ extra_kv; // [extra_num_blocks, extra_page_block_size, d_qk]
int* __restrict__ extra_indices; // [b, s_q, extra_topk]
int* __restrict__ extra_topk_length; // [b], may be nullptr
int stride_q_b, stride_q_s_q, stride_q_h_q;
int stride_kv_block, stride_kv_row;
int stride_indices_b, stride_indices_s_q;
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
int stride_extra_kv_block, stride_extra_kv_row;
int stride_extra_indices_b, stride_extra_indices_s_q;
cudaStream_t stream;
// SplitKV-related parameters
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
};
// Parameters for the combine step that merges split-KV partial results (lse_accum / o_accum)
// into the final `lse` / `out` tensors. Shapes are given next to each pointer.
// NOTE(review): strides are presumably in elements, not bytes — confirm against the kernel.
struct CombineParams {
int b, s_q, h_q, d_v;
// Final outputs
float* __restrict__ lse; // [b, s_q, h_q]
void* __restrict__ out; // [b, s_q, h_q, d_v]
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
// Per-split partial results
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
// Scheduling metadata
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
float* attn_sink; // [h_q], may be nullptr
cudaStream_t stream;
};
struct GetDecodeSchedMetaParams {
int b; // batch size
int s_q;
int block_size_n;
int fixed_overhead_num_blocks;
int topk, extra_topk; // -1 if sparse attention (or extra topk) is disabled
int *__restrict__ topk_length, *__restrict__ extra_topk_length;
int *__restrict__ seqlens_k_ptr; // Only necessary for dense attention
DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr;
int num_sm_parts;
int topk;
cudaStream_t stream;
};
struct SparsePrefillParams {
struct SparseAttnFwdParams {
int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;
float sm_scale, sm_scale_div_log2;
......@@ -70,7 +150,10 @@ struct SparsePrefillParams {
cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk]
cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk]
int* __restrict__ indices; // [s_q, h_kv, topk]
float* __restrict__ attn_sink; // [h_q], may be nullptr
int* __restrict__ topk_length; // [s_q], may be nullptr
// Strides
int stride_q_s_q; int stride_q_h_q;
int stride_kv_s_kv; int stride_kv_h_kv;
int stride_indices_s_q; int stride_indices_h_kv;
......@@ -80,5 +163,18 @@ struct SparsePrefillParams {
float* __restrict__ max_logits; // [s_q, h_q]
float* __restrict__ lse; // [s_q, h_q]
int num_sm;
cudaStream_t stream;
};
// Some kernels implement both prefill and decode modes in a single kernel (with different
// template instantiations). This enum is the template argument that selects the mode.
enum class SparseAttnFwdMode {
Prefill, // Normal prefill mode
DecodeWithSplitKV, // To trigger decoding mode for kernels that support both prefill and decode
};
// Compile-time flag: true exactly when FWD_MODE is SparseAttnFwdMode::DecodeWithSplitKV.
template<SparseAttnFwdMode FWD_MODE>
inline constexpr bool is_decode_v = (FWD_MODE == SparseAttnFwdMode::DecodeWithSplitKV);
template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/fast_math.h>
#include "params.h"
#include "smxx/get_mla_metadata.h"
#include "smxx/mla_combine.h"
#include "sm90/decode/dense/splitkv_mla.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm90/prefill/sparse/fwd.h"
#include "sm100/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/prefill/dense/interface.h"
#include "sm100/prefill/sparse/fwd.h"
// Input validation helpers. Each raises through TORCH_CHECK with a message naming the checked expression.
// CHECK_DEVICE: the tensor must live on a CUDA device.
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
// CHECK_SHAPE: the tensor's sizes must equal the listed extents exactly.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// CHECK_CONTIGUOUS: the tensor must be contiguous.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Compute capability (major.minor) of a CUDA device, with helpers for the supported architectures.
struct Arch {
    int major;
    int minor;

    // True only for compute capability 9.0 exactly.
    bool is_sm90() const {
        return minor == 0 && major == 9;
    }

    // True for any compute capability with major version 10.
    bool is_sm100() const {
        return major == 10;
    }

    // Raises (via TORCH_CHECK) unless the architecture is SM90 or SM100.
    void assert_is_supported() const {
        const bool supported = is_sm90() || is_sm100();
        TORCH_CHECK(supported, "Only SM90 and SM100 are supported");
    }
};
// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.)
// Returned by get_attn_impl_meta using positional aggregate initialisation, so the field order
// below must not change.
struct DecodingAttnImplMeta {
// Number of SM partitions; also the first dimension of the tile scheduler metadata tensor.
int num_sm_parts;
// Forwarded to the metadata kernel as params.fixed_overhead_num_blocks.
int fixed_overhead_num_blocks;
// Forwarded to the metadata kernel as params.block_size_n.
int k_block_size;
};
// Select the scheduling metadata for the decoding kernel that will run for the given
// architecture / kv-cache format / sparsity combination.
//
//   arch                    : compute capability of the target device
//   sm_count                : number of SMs on the device
//   num_q_tokens_per_head_k : number of query tokens per KV head (must be positive)
//   h_k                     : number of KV heads (must be positive)
//   h_q_                    : number of query heads; required for the sparse FP8 paths
//   is_fp8_kvcache          : whether the KV cache is stored in FP8
//   is_sparse_attn          : whether sparse (top-k) attention is used
//
// Returns {num_sm_parts, fixed_overhead_num_blocks, k_block_size}. Raises via TORCH_CHECK for
// unsupported combinations or invalid head/token counts.
DecodingAttnImplMeta get_attn_impl_meta(
    Arch arch,
    int sm_count,
    int num_q_tokens_per_head_k,
    int h_k,
    std::optional<int> h_q_,
    bool is_fp8_kvcache,
    bool is_sparse_attn
) {
    // Every supported path divides by these quantities; reject non-positive values up front
    // instead of hitting an integer division by zero.
    TORCH_CHECK(h_k > 0, "h_k must be positive");
    TORCH_CHECK(num_q_tokens_per_head_k > 0, "num_q_tokens_per_head_k must be positive");

    if (arch.is_sm90()) {
        if (is_sparse_attn) {
            if (is_fp8_kvcache) {
                TORCH_CHECK(h_q_.has_value());
                int h_q = h_q_.value();
                TORCH_CHECK(h_q % h_k == 0);
                int s_q = num_q_tokens_per_head_k * h_k / h_q;
                TORCH_CHECK(s_q > 0, "num_q_tokens_per_head_k * h_k must be at least h_q");
                // FP8 + Sparse MLA
                return {
                    std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
                    5,
                    64
                };
            } else {
                // Sparse BF16 MLA
                TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90");
            }
        } else {
            if (is_fp8_kvcache) {
                // Dense FP8 MLA
                TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
            } else {
                // Dense BF16 MLA
                return {
                    std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1),
                    5,
                    64
                };
            }
        }
    } else if (arch.is_sm100()) {
        if (is_sparse_attn) {
            if (is_fp8_kvcache) {
                TORCH_CHECK(h_q_.has_value());
                int h_q = h_q_.value();
                TORCH_CHECK(h_q % h_k == 0);
                int s_q = num_q_tokens_per_head_k * h_k / h_q;
                TORCH_CHECK(s_q > 0, "num_q_tokens_per_head_k * h_k must be at least h_q");
                // FP8 + Sparse MLA
                return {
                    std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1),
                    5,
                    64
                };
            } else {
                // Sparse BF16 MLA
                TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100");
            }
        } else {
            if (is_fp8_kvcache) {
                // FP8 MLA
                TORCH_CHECK(false, "FP8 Dense MLA is not supported on SM100");
            } else {
                // Normal BF16 MLA
                TORCH_CHECK(false, "BF16 Dense MLA is not supported on SM100");
            }
        }
    } else {
        TORCH_CHECK(false, "Unsupported GPU architecture");
    }
}
// Compute the tile scheduler metadata and per-batch split counts used by the MLA decoding kernels.
//
//   seqlens_k               : int32, contiguous CUDA tensor; its first dimension is the batch size
//   num_q_tokens_per_head_k : number of query tokens per KV head
//   h_k                     : number of KV heads
//   h_q                     : number of query heads; required when `topk` is provided
//   is_fp8_kvcache          : whether the KV cache is stored in FP8
//   topk                    : enables sparse attention when present
//
// Returns {tile_scheduler_metadata [num_sm_parts, TileSchedulerMetaDataSize], num_splits [batch_size + 1]},
// both int32 tensors on seqlens_k's device, filled by run_get_mla_metadata_kernel on the current stream.
std::vector<at::Tensor>
get_mla_decoding_metadata(
    at::Tensor &seqlens_k,
    const int num_q_tokens_per_head_k,
    const int h_k,
    const std::optional<int> h_q,
    const bool is_fp8_kvcache,
    const std::optional<int> topk
) {
    bool is_sparse_attn = topk.has_value();
    CHECK_DEVICE(seqlens_k);
    TORCH_CHECK(seqlens_k.is_contiguous());
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
    if (is_sparse_attn)
        TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");

    // Switch to the tensor's device BEFORE querying device properties, so that the SM count and
    // architecture describe the GPU that will actually run the kernels.
    at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};

    int batch_size = seqlens_k.size(0);
    int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
    auto options = seqlens_k.options();

    auto dprops = at::cuda::getCurrentDeviceProperties();
    int sm_count = dprops->multiProcessorCount;
    Arch arch = {dprops->major, dprops->minor};
    arch.assert_is_supported();
    DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn);

    auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options);
    auto num_splits = torch::empty({batch_size + 1}, options);
    int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    int *num_splits_ptr = num_splits.data_ptr<int>();

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    GetDecodingMetadataParams params = {};
    params.seqlens_k_ptr = seqlens_k_ptr;
    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
    params.num_splits_ptr = num_splits_ptr;
    params.batch_size = batch_size;
    params.block_size_n = attn_impl_meta.k_block_size;
    params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks;
    params.num_sm_parts = attn_impl_meta.num_sm_parts;
    // -1 marks dense attention (no top-k selection).
    params.topk = is_sparse_attn ? topk.value() : -1;
    run_get_mla_metadata_kernel(params, stream);

    return {tile_scheduler_metadata, num_splits};
}
// MLA decoding forward pass over a paged KV cache (split-KV kernel followed by a combine kernel).
//
// Validates dtypes, devices, layouts and shapes, reshapes `q` so that the query heads sharing a KV
// head are folded into the sequence dimension, launches the architecture-specific split-KV kernel
// and the combine kernel on the current CUDA stream, and finally restores the caller-facing layout.
//
// Note: `q` is taken by non-const reference and is re-assigned to the reshaped view.
//
// Returns {out [batch_size, seqlen_q, num_heads_q, head_size_v],
//          softmax_lse [batch_size, num_heads_q, seqlen_q] (float32)}.
// Raises via TORCH_CHECK on any unsupported configuration or malformed input.
std::vector<at::Tensor>
fwd_kvcache_mla(
    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
    const int head_size_v,
    const at::Tensor &seqlens_k,                 // batch_size
    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
    const float softmax_scale,
    bool is_causal,
    const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
    const at::Tensor &num_splits,                // batch_size + 1
    const bool &is_fp8,
    const std::optional<at::Tensor> &indices     // None, or batch_size x seqlen_q x topk
) {
    bool is_sparse_attn = indices.has_value();
    int topk = is_sparse_attn ? indices->size(-1) : -1;

    // Switch to q's device before querying device properties so the architecture check below
    // describes the GPU that will actually run the kernels.
    CHECK_DEVICE(q);
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    // Check the architecture
    auto dprops = at::cuda::getCurrentDeviceProperties();
    Arch arch = {dprops->major, dprops->minor};
    arch.assert_is_supported();

    // Check data types
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);

    if (!is_fp8) {
        TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    } else {
        TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8");
    }
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
    TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
    TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
    TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");

    // Check device
    CHECK_DEVICE(q);
    CHECK_DEVICE(kcache);
    CHECK_DEVICE(seqlens_k);
    CHECK_DEVICE(block_table);
    CHECK_DEVICE(tile_scheduler_metadata);
    CHECK_DEVICE(num_splits);
    if (is_sparse_attn) CHECK_DEVICE(indices.value());

    // Check layout
    // q.sizes() is indexed up to [3] below, so reject tensors of any other rank explicitly.
    TORCH_CHECK(q.dim() == 4, "q must have shape (batch_size, seqlen_q, num_heads, head_size)");
    TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
    CHECK_CONTIGUOUS(seqlens_k);
    TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    CHECK_CONTIGUOUS(tile_scheduler_metadata);
    CHECK_CONTIGUOUS(num_splits);
    TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");

    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q_ori = sizes[1];
    const int num_heads_q = sizes[2];
    const int head_size_k = sizes[3];
    TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
    const int page_block_size = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // A single query token cannot attend to the future, so causal masking is a no-op.
    if (seqlen_q_ori == 1) { is_causal = false; }

    // Fold the query heads that share a KV head into the sequence dimension:
    // [b, s_q, h_q, d] -> [b, s_q * (h_q / h_k), h_k, d].
    const int num_q_heads_per_hk = num_heads_q / num_heads_k;
    const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
    const int num_heads = num_heads_k;
    q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
            .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});

    CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
    if (!is_fp8) {
        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
    } else {
        // Bytes per token in the FP8 cache: 512 + 64*2 + (512/128)*4 = 656.
        int bytes_per_token = 512 + 64*2 + (512/128)*4;
        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token);
        TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True");
        TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True");
    }
    CHECK_SHAPE(seqlens_k, batch_size);
    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
    CHECK_SHAPE(num_splits, batch_size+1);
    if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk);

    auto opts = q.options();
    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    CHECK_CONTIGUOUS(softmax_lse);

    DecodingParams params = {};
    // Set the sizes.
    params.b = batch_size;
    params.s_q = seqlen_q_ori;
    params.q_seq_per_hk = q_seq_per_hk;
    params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
    params.h_q = num_heads_q;
    params.h_k = num_heads_k;
    params.num_blocks = num_blocks;
    params.q_head_per_hk = num_q_heads_per_hk;
    params.is_causal = is_causal;
    params.d = head_size_k;
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
    params.topk = topk;
    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = kcache.data_ptr();
    params.o_ptr = out.data_ptr();
    params.indices_ptr = is_sparse_attn ? indices->data_ptr<int>() : nullptr;
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    // All stride are in elements, not bytes.
    params.q_batch_stride = q.stride(0);
    params.k_batch_stride = kcache.stride(0);
    params.o_batch_stride = out.stride(0);
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = kcache.stride(1);
    params.o_row_stride = out.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = kcache.stride(2);
    params.o_head_stride = out.stride(-2);
    params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0;
    params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0;

    params.block_table = block_table.data_ptr<int>();
    params.block_table_batch_stride = block_table.stride(0);
    params.page_block_size = page_block_size;

    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
    params.num_sm_parts = tile_scheduler_metadata.size(0);
    params.num_splits_ptr = num_splits.data_ptr<int>();

    // Size of the accumulation buffers for the partial (per-split) results.
    const int total_num_splits = batch_size + params.num_sm_parts;
    at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
    CHECK_CONTIGUOUS(softmax_lse_accum);
    CHECK_CONTIGUOUS(out_accum);
    params.total_num_splits = total_num_splits;
    params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
    params.oaccum_ptr = out_accum.data_ptr();

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(head_size_k == 576);

    if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
        TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#endif
    }

    // Stage 1: split-KV attention, writing per-split partial results.
    if (arch.is_sm90()) {
        if (is_sparse_attn) {
            if (is_fp8) {
                TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90");
                sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
            } else {
                TORCH_CHECK(false, "Only FP8 kvcache is supported for sparse MLA on SM90");
            }
        } else {
            if (is_fp8) {
                TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
            } else {
                if (q_dtype == torch::kBFloat16) {
                    sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
                } else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
                    sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
#endif
                } else {
                    TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
                }
            }
        }
    } else if (arch.is_sm100()) {
        TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100");
        sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
    } else {
        TORCH_CHECK(false, "Unsupported GPU architecture");
    }

    // Stage 2: combine the per-split partial results into `out` / `softmax_lse`.
    if (q_dtype == torch::kBFloat16) {
        run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
    } else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
        run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }

    // Undo the head folding so the outputs use the caller-facing [b, s_q, h_q, ...] layout.
    out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
            .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
    softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
            .reshape({batch_size, num_heads_q, seqlen_q_ori});

    return {out, softmax_lse};
}
// Narrows a 64-bit tensor stride to `int`.
//
// Raises (via TORCH_CHECK) when the value cannot be represented as an `int`.
// Both bounds are checked so that no out-of-range value can be silently
// truncated by the static_cast below.
inline int int64_stride_to_int(int64_t orig_stride) {
    const bool fits_in_int =
        orig_stride >= static_cast<int64_t>(std::numeric_limits<int>::min()) &&
        orig_stride <= static_cast<int64_t>(std::numeric_limits<int>::max());
    TORCH_CHECK(fits_in_int, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride);
    return static_cast<int>(orig_stride);
}
// Forward pass of the sparse (top-k) prefill attention kernel.
//
// Arguments (all constraints below are enforced by the checks in the body):
//   q        : [s_q,  h_q,  d_qk], bfloat16, unit stride in the last dimension
//   kv       : [s_kv, h_kv, d_qk], bfloat16, unit stride in the last dimension
//   indices  : [s_q,  h_kv, topk], int32,    unit stride in the last dimension
//   sm_scale : softmax scale; forwarded as-is and also multiplied by 1.44269504 (~= log2(e))
//   d_v      : size of the last dimension of the output tensor
//
// Returns {out, max_logits, lse} where
//   out        : [s_q, h_q, d_v], same options (dtype/device) as q
//   max_logits : [s_q, h_q], float32
//   lse        : [s_q, h_q], float32
//
// Raises (TORCH_CHECK) on non-CUDA inputs, unsupported architecture (only compute
// capability major 9 or 10 is accepted), wrong dtype/rank/shape, non-unit last stride,
// or a stride that does not fit in an int.
std::vector<at::Tensor> sparse_prefill_fwd(
    const at::Tensor &q,
    const at::Tensor &kv,
    const at::Tensor &indices,
    float sm_scale,
    int d_v
) {
    CHECK_DEVICE(q);
    CHECK_DEVICE(kv);
    CHECK_DEVICE(indices);

    // Make q's device current *before* querying device properties, so that the
    // architecture dispatch below inspects the GPU the kernel will actually run on
    // (this matters on hosts with more than one GPU).
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9;
    bool is_sm100 = dprops->major == 10;
    TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures");

    TORCH_CHECK(q.dtype() == torch::kBFloat16);
    TORCH_CHECK(kv.dtype() == torch::kBFloat16);
    TORCH_CHECK(indices.dtype() == torch::kInt32);

    // Validate the rank up front so that a wrongly shaped input produces a clear
    // message instead of an index error from size(2) below.
    TORCH_CHECK(q.dim() == 3, "q must be a 3-D tensor, but got ", q.dim(), " dimension(s)");
    TORCH_CHECK(kv.dim() == 3, "kv must be a 3-D tensor, but got ", kv.dim(), " dimension(s)");
    TORCH_CHECK(indices.dim() == 3, "indices must be a 3-D tensor, but got ", indices.dim(), " dimension(s)");

    int s_q = q.size(0);
    int s_kv = kv.size(0);
    int h_q = q.size(1);
    int h_kv = kv.size(1);
    int d_qk = q.size(2);
    int topk = indices.size(2);

    CHECK_SHAPE(q, s_q, h_q, d_qk);
    CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
    CHECK_SHAPE(indices, s_q, h_kv, topk);

    TORCH_CHECK(q.stride(-1) == 1);
    TORCH_CHECK(kv.stride(-1) == 1);
    TORCH_CHECK(indices.stride(-1) == 1);

    auto opts = q.options();
    at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
    CHECK_CONTIGUOUS(out);

    at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
    CHECK_CONTIGUOUS(max_logits);
    CHECK_CONTIGUOUS(lse);

    // Aggregate-initialized: the order of the values below must match the field
    // order of SparsePrefillParams.
    SparsePrefillParams params = {
        s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
        sm_scale, sm_scale * 1.44269504f,

        (cutlass::bfloat16_t*)q.data_ptr(),
        (cutlass::bfloat16_t*)kv.data_ptr(),
        (int*)indices.data_ptr(),

        int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
        int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
        int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),

        (cutlass::bfloat16_t*)out.data_ptr(),
        (float*)max_logits.data_ptr(),
        (float*)lse.data_ptr(),

        at::cuda::getCurrentCUDAStream().stream()
    };

    if (is_sm90) {
        sm90::run_fwd_kernel(params);
    } else if (is_sm100) {
        sm100::run_fwd_kernel(params);
    } else {
        TORCH_CHECK(false, "Unknown architecture");
    }

    return {out, max_logits, lse};
}
// Python bindings: exposes the host entry points of this extension under the
// module name given by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashMLA";

    // module_::def returns the module itself, so the registrations can be chained.
    m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata)
     .def("fwd_kvcache_mla", &fwd_kvcache_mla)
     .def("dense_prefill_fwd", &FMHACutlassSM100FwdRun)
     .def("dense_prefill_bwd", &FMHACutlassSM100BwdRun)
     .def("sparse_prefill_fwd", &sparse_prefill_fwd);
}
Head128 decoding kernels are either located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernels.
\ No newline at end of file
#pragma once
#include "kernel.h"
#include <cuda_fp8.h>
#include <cutlass/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::decode::head64 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;
// Numeric identifiers for the named barriers of this kernel.
// NOTE(review): the set of threads participating in each barrier is not visible
// in this header; the groupings below are inferred from the enumerator names only.
enum NamedBarriers : uint32_t {
    main_loop_sync  = 0,
    wg0_sync        = 1,  // presumably warpgroup 0 only
    wg0_warp02_sync = 2,  // presumably warps 0 and 2 of warpgroup 0
    wg0_warp13_sync = 3,  // presumably warps 1 and 3 of warpgroup 0
    everyone_sync   = 4,
};
// Compile-time configuration of the sm100 head64 sparse FP8 decoding kernel for one
// model type: problem dimensions, tile sizes, shared-memory layouts and plan,
// tensor-memory column assignment, tiled MMA types, and the kernel entry points.
template<ModelType MODEL_TYPE>
struct KernelTemplate {
// Q/K head dimension: 576 for V32, 512 otherwise. K uses the same dimension as Q.
static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int D_K = D_Q;
static constexpr int D_V = 512;
// D_Q is made of a NoPE part followed by a 64-wide RoPE part (see the static_assert below).
static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
static constexpr int D_ROPE = 64;
// Presumably the number of NoPE elements that share one quantization scale — TODO confirm against the kernel body.
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
// Whether V includes the RoPE part: D_V == D_NOPE + D_ROPE when true, D_V == D_NOPE when false (asserted below).
static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included
static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens and 2) be large enough to cover the entire KV cache. Since a TMA copy's coordinates can only be 32-bit signed integers, this number must be >= 128, preferably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
static_assert(D_NOPE + D_ROPE == D_Q);
static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));
// Tile sizes. NOTE(review): presumably B_H = query heads per block and B_TOPK = tokens processed per step — confirm against the kernel body.
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64;
// Number of KV buffers (sizes the `dequant`/`raw_nope` arrays and the per-buffer barriers in SharedMemoryPlan).
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales
static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
// Q is split into a 512-wide part and the remainder (64 for V32, 0 otherwise).
// NOTE(review): the SW128/SW64 suffixes presumably refer to 128-byte / 64-byte swizzled storage — confirm.
static constexpr int D_Q_SW128 = 512;
static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes
// Bundle of TMA descriptors handed to the device function: typed shape/descriptor
// pairs for the SW128 part of Q and for O, plus raw CUtensorMap objects for the
// SW64 part of Q and for the (extra) KV NoPE/RoPE parts.
template<
typename Shape_Q_SW128, typename TMA_Q_SW128,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope;
CUtensorMap tensor_map_extra_kv_rope;
};
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 256 + 64*D_Q/256: Q
// 400 ~ 464: P
static constexpr int O = 0;
static constexpr int Q = 256;
// NOTE(review): presumably the first column after the NoPE part of Q — confirm against the kernel body.
static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
static constexpr int P = 400;
};
// Shared-memory layout of B_H rows by NUM_TILES*64 bf16 columns, tiled from the SW128 K-major atom.
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<NUM_TILES*64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;
// Layout of the bf16 output buffer (B_H x D_V), tiled from the SW128 K-major atom.
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64>>{}
)); // A TMA tile
static_assert(D_V == 512);
// Row-major float accumulator layout (B_H x D_V) with a padded row stride.
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
// Layout of the B_H x B_TOPK bf16 `s` buffer (see SharedMemoryPlan::s_p), tiled from the K-major interleaved atom.
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
// K tile layouts, SW128 atom: B_H rows by 64*NUM_TILES bf16 columns.
template<int NUM_TILES>
using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Same as above with twice as many rows (B_H*2), for the dual GEMM.
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Transposed view of SmemLayoutKTiles_SW128, obtained by composing it with a (64*NUM_TILES, B_TOPK) layout.
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
SmemLayoutKTiles_SW128<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
// K tile layouts, SW64 atom: B_H rows by 32*NUM_TILES bf16 columns.
template<int NUM_TILES>
using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Same as above with twice as many rows (B_H*2), for the dual GEMM.
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Transposed view of SmemLayoutKTiles_SW64, obtained by composing it with a (32*NUM_TILES, B_TOPK) layout.
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
SmemLayoutKTiles_SW64<NUM_TILES>{},
Layout<
Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
// Shared-memory plan of the kernel. Member order determines the memory layout, so do not reorder.
struct SharedMemoryPlan {
// `u` overlays two views of the same storage: `qo` (Q plus the output buffers) and `kv` (dequantized and raw KV buffers).
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;
bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
// The bf16 output buffer and the float accumulator buffer share storage.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
} o;
} qo;
struct {
struct {
array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized
array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
} dequant[NUM_BUFS];
static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not cover raw_nope
array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part
} kv;
} u;
// `p_exchange_buf` and `s` share storage.
union {
float4 p_exchange_buf[4][16 * B_TOPK / 4];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
} s_p;
CUTE_ALIGNAS(16) float rowwise_max_buf[128];
// Per-index-buffer metadata: validity bits (B_TOPK/8 bytes, i.e. one bit per token), TMA coordinates, and scales.
char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
int tma_coord[NUM_INDEX_BUFS][B_TOPK];
e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
array_aligned<uint32_t, 1> tmem_start_addr;
// Barriers. NOTE(review): producer/consumer roles are inferred from the names only; the kernel body is not visible here.
transac_bar_t bar_last_store_done;
transac_bar_t bar_q_tma, bar_q_utccp;
transac_bar_t bar_rope_ready[NUM_BUFS];
transac_bar_t bar_nope_ready[NUM_BUFS];
transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
};
// Tiled MMA computing P: M = B_H, N = B_TOPK*2, both operands K-major.
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
// Tiled MMA computing O: M = B_H, N = 256, A operand K-major and B operand MN-major.
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
// Device-side kernel body; only declared here.
template<typename TmaParam>
static __device__ void
flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);
// Host-side entry point; only declared here. NOTE(review): presumably builds the TMA parameters and launches the kernel — confirm in the definition.
static void run(const SparseAttnDecodeParams &params);
};
}
\ No newline at end of file
#include "../kernel.cuh"
namespace sm100::decode::head64 {

// Explicit instantiation of the head64 sparse FP8 decoding launcher for ModelType::MODEL1,
// so that this translation unit emits the code for that model type.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);

}  // namespace sm100::decode::head64
#include "../kernel.cuh"
namespace sm100::decode::head64 {

// Explicit instantiation of the head64 sparse FP8 decoding launcher for ModelType::V32,
// so that this translation unit emits the code for that model type.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);

}  // namespace sm100::decode::head64
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm100::decode::head64 {

// Launcher for the head64 sparse FP8 decoding kernel, selected at compile time by MODEL_TYPE.
// Only declared here; the definition and its explicit instantiations live in other files.
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);

}  // namespace sm100::decode::head64
This diff is collapsed.
This diff is collapsed.
#pragma once
#include "params.h"
namespace sm100 {

// SM100 entry point for the sparse FP8 split-KV MLA decoding kernel.
// Takes the decoding parameters by non-const reference together with the CUDA stream
// to use; only declared here, the definition lives in another file.
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);

}  // namespace sm100
This diff is collapsed.
This diff is collapsed.
......@@ -34,7 +34,7 @@
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "utils.h" // for IS_SM100
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
namespace cutlass::fmha::kernel {
......@@ -139,7 +139,7 @@ struct FmhaKernelBwdConvert {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100
#if defined(KERUTILS_ENABLE_SM100A)
if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment