Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async)
template<typename T>
CUTE_DEVICE
static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) {
static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async.");
long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
:
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
);
}
static constexpr int PEER_ADDR_MASK = 16777216;
// Given an address in the current CTA, return the corresponding address in the peer CTA
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0)
template<typename T>
CUTE_DEVICE
T* get_cta0_addr(const T* p) {
constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF;
return (T*)((int64_t)(p) & CTA0_ADDR_MASK);
}
// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
:
: "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE
void barrier_cluster_arrive_release() {
asm volatile("barrier.cluster.arrive.release;" : : : "memory");
}
// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE
void barrier_cluster_arrive_relaxed() {
asm volatile("barrier.cluster.arrive.relaxed;" : : :);
}
// Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE
void barrier_cluster_wait_acquire() {
asm volatile("barrier.cluster.wait.acquire;" : : : "memory");
}
// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
CUTE_DEVICE
void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile(
"{\n\t"
"mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t"
"}"
:
: "r"(smem_addr));
}
// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
CUTE_DEVICE
void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %6, 1;\n\t"
"@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
"}"
:
: "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
"l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
);
}
// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
CUTE_DEVICE
void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {
uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile(
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
:
: "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr)
);
}
}
#pragma once
#include <exception>
#include <string>
#include <sstream>
#include <vector>
#include <cuda_runtime_api.h>
#include <cuda.h>
#include <cutlass/cuda_host_adapter.hpp>
#include "kerutils/common/common.h"
namespace kerutils {
class KUException final : public std::exception {
std::string message = {};
public:
template<typename... Args>
explicit KUException(const char *name, const char* file, const int line, Args&&... args) {
std::ostringstream oss;
oss << name << " error (" << file << ":" << line << "): ";
(oss << ... << args);
message = oss.str();
}
const char *what() const noexcept override {
return message.c_str();
}
};
#define THROW_KU_EXCEPTION(name, ...) \
throw kerutils::KUException(name, __FILE__, __LINE__, __VA_ARGS__)
#define KU_CUDA_CHECK(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
THROW_KU_EXCEPTION("CUDA", "CUDA error: ", cudaGetErrorString(status_)); \
} \
} while(0)
#define KU_CUTLASS_CHECK(call) \
do { \
cutlass::Status status_ = call; \
if (status_ != cutlass::Status::kSuccess) { \
fprintf(stderr, "CUTLASS error (%s:%d): %d\n", __FILE__, __LINE__, static_cast<int>(status_)); \
THROW_KU_EXCEPTION("CUTLASS", "CUTLASS error: ", static_cast<int>(status_)); \
} \
} while(0)
// This `KU_ASSERT` is triggered no matter if the code is compiled with `-DNDEBUG` or not.
#define KU_ASSERT(cond, ...) \
do { \
if (not (cond)) { \
fprintf(stderr, "Assertion `%s` failed (%s:%d): ", #cond, __FILE__, __LINE__); \
if constexpr (sizeof(#__VA_ARGS__) > 1) { \
fprintf(stderr, ", " __VA_ARGS__); \
} \
fprintf(stderr, "\n"); \
THROW_KU_EXCEPTION("Assertion", "Assertion `", #cond, "` failed."); \
} \
} while(0)
#define KU_CHECK_KERNEL_LAUNCH() KU_CUDA_CHECK(cudaGetLastError())
template<typename T>
inline __host__ __device__ constexpr T ceil_div(const T &a, const T &b) {
return (a + b - 1) / b;
}
template<typename T>
inline __host__ __device__ constexpr T ceil(const T &a, const T &b) {
return (a + b - 1) / b * b;
}
// A wrapper for make_tensor_map
static inline CUtensorMap make_tensor_map(
const std::vector<uint64_t> &size,
const std::vector<uint64_t> &strides, // PAY ATTENTION: In BYTES
const std::vector<uint32_t> &box_size,
void* global_ptr,
CUtensorMapDataType data_type,
CUtensorMapSwizzle swizzle_mode,
CUtensorMapL2promotion l2_promotion,
CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
const std::vector<uint32_t> &element_strides_ = {}
) {
int dim = size.size();
KU_ASSERT(dim >= 1);
std::vector<uint32_t> element_strides;
if (element_strides_.empty()) {
for (int i = 0; i < dim; ++i)
element_strides.push_back(1);
} else {
element_strides = element_strides_;
}
KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim);
CUtensorMap result;
CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&result,
data_type,
dim,
global_ptr,
size.data(),
strides.data(),
box_size.data(),
element_strides.data(),
interleave_mode,
swizzle_mode,
l2_promotion,
oob_fill
);
if (ret_code != CUresult::CUDA_SUCCESS) {
auto print_vector = [&](auto t, const char* fmt, const char end='\n') {
for (auto elem : t) {
printf(fmt, elem);
}
printf("%c", end);
};
fprintf(stderr, "Failed to create tensormap\n");
fprintf(stderr, "Dim: %d\n", dim);
printf("size: "); print_vector(size, "%lu ");
printf("strides: "); print_vector(strides, "%lu ");
printf("box_size: "); print_vector(box_size, "%u ");
printf("element_strides: "); print_vector(element_strides, "%u ");
printf("global ptr: 0x%lx\n", (int64_t)global_ptr);
printf("data_type: %d\n", (int)data_type);
printf("swizzle_mode: %d\n", (int)swizzle_mode);
printf("l2_promotion: %d\n", (int)l2_promotion);
printf("interleave_mode: %d\n", (int)interleave_mode);
printf("oob_fill: %d\n", (int)oob_fill);
KU_ASSERT(false);
}
return result;
}
// Given strides (in number of elements), this function converts their datatype in uint64_t and then multiplies by elem_size
template<typename T>
static inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) {
std::vector<uint64_t> res;
for (auto stride : strides_in_elems) {
res.push_back(((uint64_t)stride) * elem_size);
}
return res;
}
}
\ No newline at end of file
#pragma once
#include "host/host.h"
#include "device/device.cuh"
#pragma once
#include <functional>
#include <torch/python.h>
#include "kerutils/common/common.h"
namespace kerutils {
// Check whether the given tensor or optional tensor satisfies the given condition
// If tensor_or_opt is a tensor, check_fn is applied directly
// If tensor_or_opt is an optional tensor, check_fn is applied only when the optional has value
template<typename T>
static inline bool _check_optional_tensor(const T& tensor_or_opt, const std::function<bool(const at::Tensor&)>& check_fn) {
if constexpr (std::is_same<T, at::Tensor>::value) {
return check_fn(tensor_or_opt);
} else {
if (tensor_or_opt.has_value()) {
return check_fn(tensor_or_opt.value());
} else {
return true;
}
}
}
// Get the pointer of the given tensor
// Return (PtrT*)tensor.data_ptr() if the tensor has a backend storage, nullptr otherwise
template<typename PtrT>
static inline PtrT* get_tensor_ptr(const at::Tensor& tensor) {
if (tensor.has_storage()) {
return (PtrT*)tensor.data_ptr();
} else {
return nullptr;
}
}
// Get the pointer of the given tensor or optional tensor
// Return (PtrT*)tensor.data_ptr() if tensor_or_opt has value and points to a valid tensor, return nullptr otherwise
template<typename PtrT, typename T>
static inline PtrT* get_optional_tensor_ptr(const T& tensor_or_opt) {
if constexpr (std::is_same<T, at::Tensor>::value) {
return get_tensor_ptr<PtrT>(tensor_or_opt);
} else {
if (tensor_or_opt.has_value()) {
return get_tensor_ptr<PtrT>(*tensor_or_opt);
} else {
return nullptr;
}
}
}
}
// Check whether the given tensor (or optional<tensor>) is on cuda
#define KU_CHECK_DEVICE(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_cuda(); }), #tensor " must be on CUDA")
// Check whether the given tensor (or optional<tensor>) has the given number of dimensions
#define KU_CHECK_NDIM(tensor, ndim) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.dim() == (ndim); }), #tensor " must have " #ndim " dimensions")
// Check whether the given tensor (or optional<tensor>) has the given shape
#define KU_CHECK_SHAPE(tensor, ...) TORCH_CHECK(ku::_check_optional_tensor(tensor, [&](const at::Tensor& t) { return t.sizes() == torch::IntArrayRef({__VA_ARGS__}); }), #tensor " must have shape (" #__VA_ARGS__ ")")
// Check whether the given tensor (or optional<tensor>) is contiguous
#define KU_CHECK_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.is_contiguous(); }), #tensor " must be contiguous")
// Check whether the last dimention of the given tensor (or optional<tensor>)
#define KU_CHECK_LAST_DIM_CONTIGUOUS(tensor) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.size(-1) == 1 || t.stride(-1) == 1; }), #tensor " must have contiguous last dimension")
// Check whether the given tensor (or optional<tensor>) has the specified dtype
#define KU_CHECK_DTYPE(tensor, target_dtype) TORCH_CHECK(ku::_check_optional_tensor(tensor, [](const at::Tensor& t) { return t.dtype() == (target_dtype); }), #tensor " must have dtype " #target_dtype)
...@@ -2,7 +2,21 @@ ...@@ -2,7 +2,21 @@
#include "cutlass/bfloat16.h" #include "cutlass/bfloat16.h"
struct DecodingParams { enum class ModelType {
V32,
MODEL1
};
struct __align__(4*8) DecodingSchedMeta {
int begin_req_idx, end_req_idx; // Both inclusive
int begin_block_idx, end_block_idx; // Inclusive, exclusive
int begin_split_idx;
int is_first_req_splitted, is_last_req_splitted;
int _pad[1];
};
static constexpr int DecodingSchedMetaSize = sizeof(DecodingSchedMeta);
struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
using index_t = int64_t; using index_t = int64_t;
int b; // batch size int b; // batch size
...@@ -14,13 +28,11 @@ struct DecodingParams { ...@@ -14,13 +28,11 @@ struct DecodingParams {
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal; bool is_causal;
float scale_softmax, scale_softmax_log2; float scale_softmax, scale_softmax_log2;
int topk;
void *__restrict__ q_ptr; void *__restrict__ q_ptr;
void *__restrict__ k_ptr; void *__restrict__ k_ptr;
void *__restrict__ o_ptr; void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr; float *__restrict__ softmax_lse_ptr;
int *__restrict__ indices_ptr;
index_t q_batch_stride; index_t q_batch_stride;
index_t k_batch_stride; index_t k_batch_stride;
...@@ -31,38 +43,106 @@ struct DecodingParams { ...@@ -31,38 +43,106 @@ struct DecodingParams {
index_t q_head_stride; index_t q_head_stride;
index_t k_head_stride; index_t k_head_stride;
index_t o_head_stride; index_t o_head_stride;
index_t indices_batch_stride;
index_t indices_row_stride;
int *__restrict__ block_table; int *__restrict__ block_table;
index_t block_table_batch_stride; index_t block_table_batch_stride;
int page_block_size; int page_block_size;
int *__restrict__ seqlens_k_ptr; int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr; DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts; int num_sm_parts;
int *__restrict__ num_splits_ptr; int *__restrict__ num_splits_ptr;
int total_num_splits; int total_num_splits;
void *__restrict__ softmax_lseaccum_ptr; float *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr; float *__restrict__ oaccum_ptr;
cudaStream_t stream;
}; };
static constexpr int TileSchedulerMetaDataSize = 8; struct SparseAttnDecodeParams {
// [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _] int b, s_q;
int h_q, h_kv;
int d_qk, d_v;
float sm_scale, sm_scale_div_log2;
int num_blocks, page_block_size, topk;
ModelType model_type;
struct GetDecodingMetadataParams { cutlass::bfloat16_t* __restrict__ q; // [b, s_q, h_q, d_qk]
int *__restrict__ seqlens_k_ptr; cutlass::bfloat16_t* __restrict__ kv; // [num_blocks, page_block_size, d_qk]
int *__restrict__ tile_scheduler_metadata_ptr; int* __restrict__ indices; // [b, s_q, topk]
int *__restrict__ num_splits_ptr; int* __restrict__ topk_length; // [b], may be nullptr
int batch_size; float* __restrict__ attn_sink; // [h_q], may be nullptr
float* __restrict__ lse; // [b, s_q, h_q]
cutlass::bfloat16_t* __restrict__ out; // [b, s_q, h_q, d_v]
int extra_num_blocks, extra_page_block_size, extra_topk;
cutlass::bfloat16_t* __restrict__ extra_kv; // [extra_num_blocks, extra_page_block_size, d_qk]
int* __restrict__ extra_indices; // [b, s_q, extra_topk]
int* __restrict__ extra_topk_length; // [b], may be nullptr
int stride_q_b, stride_q_s_q, stride_q_h_q;
int stride_kv_block, stride_kv_row;
int stride_indices_b, stride_indices_s_q;
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
int stride_extra_kv_block, stride_extra_kv_row;
int stride_extra_indices_b, stride_extra_indices_s_q;
cudaStream_t stream;
// SplitKV-related parameters
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
};
struct CombineParams {
int b, s_q, h_q, d_v;
float* __restrict__ lse; // [b, s_q, h_q]
void* __restrict__ out; // [b, s_q, h_q, d_v]
int stride_lse_b, stride_lse_s_q;
int stride_o_b, stride_o_s_q, stride_o_h_q;
float* __restrict__ lse_accum; // [num_splits, s_q, h_q]
float* __restrict__ o_accum; // [num_splits, s_q, h_q, d_v]
int stride_lse_accum_split, stride_lse_accum_s_q;
int stride_o_accum_split, stride_o_accum_s_q, stride_o_accum_h_q;
DecodingSchedMeta* __restrict__ tile_scheduler_metadata_ptr; // [num_sm_parts, ], contiguous
int* __restrict__ num_splits_ptr; // [batch_size+1, ], contiguous
int num_sm_parts;
float* attn_sink; // [h_q], may be nullptr
cudaStream_t stream;
};
struct GetDecodeSchedMetaParams {
int b; // batch size
int s_q;
int block_size_n; int block_size_n;
int fixed_overhead_num_blocks; int fixed_overhead_num_blocks;
int topk, extra_topk; // -1 if sparse attention (or extra topk) is disabled
int *__restrict__ topk_length, *__restrict__ extra_topk_length;
int *__restrict__ seqlens_k_ptr; // Only necessary for dense attention
DecodingSchedMeta *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr;
int num_sm_parts; int num_sm_parts;
int topk;
cudaStream_t stream;
}; };
struct SparsePrefillParams { struct SparseAttnFwdParams {
int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk; int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;
float sm_scale, sm_scale_div_log2; float sm_scale, sm_scale_div_log2;
...@@ -70,7 +150,10 @@ struct SparsePrefillParams { ...@@ -70,7 +150,10 @@ struct SparsePrefillParams {
cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk] cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk]
cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk] cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk]
int* __restrict__ indices; // [s_q, h_kv, topk] int* __restrict__ indices; // [s_q, h_kv, topk]
float* __restrict__ attn_sink; // [h_q], may be nullptr
int* __restrict__ topk_length; // [s_q], may be nullptr
// Strides
int stride_q_s_q; int stride_q_h_q; int stride_q_s_q; int stride_q_h_q;
int stride_kv_s_kv; int stride_kv_h_kv; int stride_kv_s_kv; int stride_kv_h_kv;
int stride_indices_s_q; int stride_indices_h_kv; int stride_indices_s_q; int stride_indices_h_kv;
...@@ -80,5 +163,18 @@ struct SparsePrefillParams { ...@@ -80,5 +163,18 @@ struct SparsePrefillParams {
float* __restrict__ max_logits; // [s_q, h_q] float* __restrict__ max_logits; // [s_q, h_q]
float* __restrict__ lse; // [s_q, h_q] float* __restrict__ lse; // [s_q, h_q]
int num_sm;
cudaStream_t stream; cudaStream_t stream;
}; };
// We have some kernels that implement both prefill and decode modes in a single kernel (with different template instantiations). The following enum helps to distinguish the modes.
enum class SparseAttnFwdMode {
Prefill, // Normal prefill mode
DecodeWithSplitKV, // To trigger decoding mode for kernels that support both prefill and decode
};
template<SparseAttnFwdMode FWD_MODE>
inline constexpr bool is_decode_v = std::bool_constant<FWD_MODE == SparseAttnFwdMode::DecodeWithSplitKV>::value;
template<SparseAttnFwdMode FWD_MODE>
using SparseFwdArgT = std::conditional_t<is_decode_v<FWD_MODE>, SparseAttnDecodeParams, SparseAttnFwdParams>;
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/fast_math.h>
#include "params.h"
#include "smxx/get_mla_metadata.h"
#include "smxx/mla_combine.h"
#include "sm90/decode/dense/splitkv_mla.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm90/prefill/sparse/fwd.h"
#include "sm100/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/prefill/dense/interface.h"
#include "sm100/prefill/sparse/fwd.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
struct Arch {
int major;
int minor;
bool is_sm90() const {
return major == 9 && minor == 0;
}
bool is_sm100() const {
return major == 10;
}
void assert_is_supported() const {
TORCH_CHECK(is_sm90() || is_sm100(), "Only SM90 and SM100 are supported");
}
};
// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.)
struct DecodingAttnImplMeta {
int num_sm_parts;
int fixed_overhead_num_blocks;
int k_block_size;
};
DecodingAttnImplMeta get_attn_impl_meta(
Arch arch,
int sm_count,
int num_q_tokens_per_head_k,
int h_k,
std::optional<int> h_q_,
bool is_fp8_kvcache,
bool is_sparse_attn
) {
if (arch.is_sm90()) {
if (is_sparse_attn) {
if (is_fp8_kvcache) {
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0);
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// FP8 + Sparse MLA
return {
std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
5,
64
};
} else {
// Sparse BF16 MLA
TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90");
}
} else {
if (is_fp8_kvcache) {
// Dense FP8 MLA
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
} else {
// Dense BF16 MLA
return {
std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1),
5,
64
};
}
}
} else if (arch.is_sm100()) {
if (is_sparse_attn) {
if (is_fp8_kvcache) {
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0);
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// FP8 + Sparse MLA
return {
std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1),
5,
64
};
} else {
// Sparse BF16 MLA
TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100");
}
} else {
if (is_fp8_kvcache) {
// FP8 MLA
TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100");
} else {
// Normal BF16 MLA
TORCH_CHECK(false, "BF16 Dence MLA is not supported on SM100");
}
}
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
}
std::vector<at::Tensor>
get_mla_decoding_metadata(
at::Tensor &seqlens_k,
const int num_q_tokens_per_head_k,
const int h_k,
const std::optional<int> h_q,
const bool is_fp8_kvcache,
const std::optional<int> topk
) {
bool is_sparse_attn = topk.has_value();
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
if (is_sparse_attn)
TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");
int batch_size = seqlens_k.size(0);
int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
auto options = seqlens_k.options();
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
Arch arch = {dprops->major, dprops->minor};
arch.assert_is_supported();
DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn);
auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
int *num_splits_ptr = num_splits.data_ptr<int>();
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
GetDecodingMetadataParams params = {};
params.seqlens_k_ptr = seqlens_k_ptr;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = attn_impl_meta.k_block_size;
params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks;
params.num_sm_parts = attn_impl_meta.num_sm_parts;
params.topk = is_sparse_attn ? topk.value() : -1;
run_get_mla_metadata_kernel(params, stream);
return {tile_scheduler_metadata, num_splits};
}
std::vector<at::Tensor>
fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits, // batch_size + 1
const bool &is_fp8,
const std::optional<at::Tensor> &indices // None, or batch_size x seqlen_q x topk
) {
bool is_sparse_attn = indices.has_value();
int topk = is_sparse_attn ? indices->size(-1) : -1;
// Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
Arch arch = {dprops->major, dprops->minor};
arch.assert_is_supported();
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
if (!is_fp8) {
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
} else {
TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8");
}
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");
// Check device
CHECK_DEVICE(q);
CHECK_DEVICE(kcache);
CHECK_DEVICE(seqlens_k);
CHECK_DEVICE(block_table);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_DEVICE(num_splits);
if (is_sparse_attn) CHECK_DEVICE(indices.value());
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
CHECK_CONTIGUOUS(tile_scheduler_metadata);
CHECK_CONTIGUOUS(num_splits);
TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
if (!is_fp8) {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
} else {
int bytes_per_token = 512 + 64*2 + (512/128)*4;
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token);
TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True");
TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True");
}
CHECK_SHAPE(seqlens_k, batch_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_SHAPE(num_splits, batch_size+1);
if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse);
DecodingParams params = {};
// Set the sizes.
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
params.topk = topk;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.indices_ptr = is_sparse_attn ? indices->data_ptr<int>() : nullptr;
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(-2);
params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0;
params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0;
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
params.num_splits_ptr = num_splits.data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse_accum);
CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size_k == 576);
if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#endif
}
if (arch.is_sm90()) {
if (is_sparse_attn) {
if (is_fp8) {
TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90");
sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else {
TORCH_CHECK(false, "Only FP8 kvcahe is supported for sparse MLA on SM90");
}
} else {
if (is_fp8) {
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
} else {
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
}
}
} else if (arch.is_sm100()) {
TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100");
sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
if (q_dtype == torch::kBFloat16) {
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, softmax_lse};
}
inline int int64_stride_to_int(int64_t orig_stride) {
if (orig_stride > std::numeric_limits<int>::max()) {
TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride);
}
return static_cast<int>(orig_stride);
}
std::vector<at::Tensor> sparse_prefill_fwd(
const at::Tensor &q,
const at::Tensor &kv,
const at::Tensor &indices,
float sm_scale,
int d_v
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9;
bool is_sm100 = dprops->major == 10;
TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures");
CHECK_DEVICE(q);
CHECK_DEVICE(kv);
CHECK_DEVICE(indices);
TORCH_CHECK(q.dtype() == torch::kBFloat16);
TORCH_CHECK(kv.dtype() == torch::kBFloat16);
TORCH_CHECK(indices.dtype() == torch::kInt32);
int s_q = q.size(0);
int s_kv = kv.size(0);
int h_q = q.size(1);
int h_kv = kv.size(1);
int d_qk = q.size(2);
int topk = indices.size(2);
CHECK_SHAPE(q, s_q, h_q, d_qk);
CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
CHECK_SHAPE(indices, s_q, h_kv, topk);
TORCH_CHECK(q.stride(-1) == 1);
TORCH_CHECK(kv.stride(-1) == 1);
TORCH_CHECK(indices.stride(-1) == 1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
CHECK_CONTIGUOUS(out);
at::Tensor buf_attn_score, max_logits, lse, p_sum;
max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
CHECK_CONTIGUOUS(max_logits);
CHECK_CONTIGUOUS(lse);
SparsePrefillParams params = {
s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
sm_scale, sm_scale * 1.44269504f,
(cutlass::bfloat16_t*)q.data_ptr(),
(cutlass::bfloat16_t*)kv.data_ptr(),
(int*)indices.data_ptr(),
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
(cutlass::bfloat16_t*)out.data_ptr(),
(float*)max_logits.data_ptr(),
(float*)lse.data_ptr(),
at::cuda::getCurrentCUDAStream().stream()
};
if (is_sm90) {
sm90::run_fwd_kernel(params);
} else if (is_sm100) {
sm100::run_fwd_kernel(params);
} else {
TORCH_CHECK(false, "Unknown architecture");
}
return {out, max_logits, lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata);
m.def("fwd_kvcache_mla", &fwd_kvcache_mla);
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
m.def("sparse_prefill_fwd", &sparse_prefill_fwd);
}
Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernel
\ No newline at end of file
#pragma once
#include "kernel.h"
#include <cuda_fp8.h>
#include <cutlass/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::decode::head64 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;
enum NamedBarriers : uint32_t {
main_loop_sync = 0,
wg0_sync = 1,
wg0_warp02_sync = 2,
wg0_warp13_sync = 3,
everyone_sync = 4
};
template<ModelType MODEL_TYPE>
struct KernelTemplate {
static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int D_K = D_Q;
static constexpr int D_V = 512;
static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
static constexpr int D_ROPE = 64;
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included
static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE; // Stride of K's tensormap. This stride must 1) be a factor of the actual stride between tokens 2) large enough to cover the entire KV cache. Since TMA copy's coordinate can only be 32bit signed integers, this number must >= 128, perferrably >= 256. So we set this to 656 for V32 and 576 for MODEL1. Extra padding may be necessary for KV blocks.
static_assert(D_NOPE + D_ROPE == D_Q);
static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64;
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales
static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
static constexpr int D_Q_SW128 = 512;
static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes
template<
typename Shape_Q_SW128, typename TMA_Q_SW128,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope;
CUtensorMap tensor_map_extra_kv_rope;
};
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 256 + 64*D_Q/256: Q
// 400 ~ 464: P
static constexpr int O = 0;
static constexpr int Q = 256;
static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
static constexpr int P = 400;
};
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<NUM_TILES*64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64>>{}
)); // A TMA tile
static_assert(D_V == 512);
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
SmemLayoutKTiles_SW128<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
SmemLayoutKTiles_SW64<NUM_TILES>{},
Layout<
Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
struct SharedMemoryPlan {
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;
bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
} o;
} qo;
struct {
struct {
array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized
array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
} dequant[NUM_BUFS];
static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not covers raw_nope
array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part
} kv;
} u;
union {
float4 p_exchange_buf[4][16 * B_TOPK / 4];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
} s_p;
CUTE_ALIGNAS(16) float rowwise_max_buf[128];
char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
int tma_coord[NUM_INDEX_BUFS][B_TOPK];
e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
array_aligned<uint32_t, 1> tmem_start_addr;
transac_bar_t bar_last_store_done;
transac_bar_t bar_q_tma, bar_q_utccp;
transac_bar_t bar_rope_ready[NUM_BUFS];
transac_bar_t bar_nope_ready[NUM_BUFS];
transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
};
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
template<typename TmaParam>
static __device__ void
flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);
static void run(const SparseAttnDecodeParams &params);
};
}
\ No newline at end of file
#include "../kernel.cuh"
namespace sm100::decode::head64 {
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);
}
#include "../kernel.cuh"
namespace sm100::decode::head64 {
template
void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);
}
#include "kernel.h"
#include <math_constants.h>
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include "kerutils/kerutils.cuh"
#include "utils.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::decode::head64 {
template<ModelType MODEL_TYPE>
template<typename TmaParam>
__device__ void
KernelTemplate<MODEL_TYPE>
::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params) {
#if defined(KERUTILS_ENABLE_SM100A)
const int s_q_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
}
if (warp_idx == 0) {
if (elect_one_sync()) {
plan.bar_last_store_done.init(128);
plan.bar_q_tma.init(1);
plan.bar_q_utccp.init(1);
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_rope_ready[i].init(1);
plan.bar_nope_ready[i].init(128);
plan.bar_raw_ready[i].init(1);
plan.bar_raw_free[i].init(128);
plan.bar_qk_done[i].init(1);
plan.bar_so_ready[i].init(128);
plan.bar_sv_done[i].init(1);
}
for (int i = 0; i < NUM_INDEX_BUFS; ++i) {
plan.bar_valid_coord_scale_ready[i].init(32);
plan.bar_valid_coord_scale_free[i].init(128+128+1+1);
}
cutlass::arch::fence_barrier_init();
}
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
struct MainLoopArgs {
int batch_idx, start_block_idx, end_block_idx;
bool is_no_split; int n_split_idx;
bool bar_phase_batch_rel; // Bar phase of barriers that are used once per batch
int topk_length, extra_topk_length, num_orig_kv_blocks;
bool is_last_batch;
};
auto run_main_loop = [&](auto f) {
// NOTE Putting the following code outside the warpgroup specialization switch results in register spilling.
// [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted;
DecodingSchedMeta sched_meta;
KU_LDG_256(
params.tile_scheduler_metadata_ptr + partition_idx,
&sched_meta,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
if (sched_meta.begin_req_idx >= params.b) {
return;
}
bool bar_phase_batch_rel = 0;
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;
bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);
MainLoopArgs args = {
batch_idx, start_block_idx, end_block_idx,
!is_split, n_split_idx,
bar_phase_batch_rel,
topk_length, extra_topk_length,
orig_topk_padded / B_TOPK,
batch_idx == sched_meta.end_req_idx
};
f(args);
NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned();
}
};
struct RingState {
int buf_idx = 0;
bool bar_phase = 0;
int index_buf_idx = 0;
bool index_bar_phase = 0;
CUTE_DEVICE void update() {
bar_phase ^= (buf_idx == NUM_BUFS-1);
buf_idx = (buf_idx+1) % NUM_BUFS;
index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1);
index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS;
}
};
RingState rs;
if (warpgroup_idx == 0) {
// Scale & Exp warpgroup
// The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel
cutlass::arch::warpgroup_reg_alloc<224>();
constexpr int B_EPI = 64; // Must be equal to the size of the swizzle atom
Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{});
bf16* sO_bases[B_EPI/8]; // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i)
sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8);
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F;
run_main_loop([&](const MainLoopArgs &args) {
cute::tma_store_wait<0>();
plan.bar_last_store_done.arrive();
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // Make sure all intermediate buffers (including p_exchange_buf, rowwise max_buf) are free
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); // Put the barrier wait here for more code reordering space
plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
// Load P
float p[B_TOPK/2], p_peer[B_TOPK/2];
if (warp_idx < 2) {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p_peer);
} else {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p_peer);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p);
}
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
// Reduce within shared mem
{
// Store
// Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i)
plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4);
NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); // Synchronize between warp 0 and warp 2, as well as warp 1 - warp 3
// Load
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i) {
float2 t[2];
*(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx];
float2* cur_p = (float2*)(p + i*4);
cur_p[0] = ku::float2_add(cur_p[0], t[0]);
cur_p[1] = ku::float2_add(cur_p[1], t[1]);
}
}
// Since dual gemm is utilized, the layout of P in register now look like:
//
// 32 32
// +-------+-------+
// | | |
// 32 | Warp0 | Warp2 |
// | | |
// +-------+-------+
// | | |
// 32 | Warp1 | Warp3 |
// | | |
// +-------+-------+
// Mask
uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0));
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2; i += 1) {
if (!(valid_mask>>i&1))
p[i] = -CUDART_INF_F;
}
// Get rowwise max of Pi
float cur_pi_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2); i += 1) {
cur_pi_max = max(cur_pi_max, p[i]);
}
cur_pi_max *= params.sm_scale_div_log2;
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // This also separates "reading p_exchange_buf" and "writing S"
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
__nv_bfloat162 s[(B_TOPK/2)/2];
float2 neg_new_max = float2 {-new_max, -new_max};
float2 cur_sum = float2 {0.0f, 0.0f};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale, neg_new_max);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
cur_sum = ku::float2_add(cur_sum, d);
s[i] = __float22bfloat162_rn(d);
}
li = fma(li, scale_for_old, (cur_sum.x + cur_sum.y));
// Write S
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2/8; i += 1) {
*(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);
}
// Scale O
if (block_idx != args.start_block_idx && should_scale_o) {
float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
ku::tcgen05_after_thread_sync();
static constexpr int CHUNK_SIZE = 64;
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = ku::float2_mul(o[i], scale_for_old_float2);
}
// Store O
ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready[rs.buf_idx].arrive();
if (block_idx != args.end_block_idx-1) {
rs.update(); // Don't update rs for the last round since we want to wait for the last SV gemm
}
}
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_max_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
li += plan.rowwise_max_buf[idx_in_warpgroup^64];
// Store li
if (idx_in_warpgroup < B_H) {
if (args.is_no_split) {
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
float* gSoftmaxLse = (float*)params.lse + args.batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + idx_in_warpgroup;
*gSoftmaxLse = cur_lse;
} else {
float cur_lse = log2f(li) + mi;
float* gSoftmaxLseAccum = (float*)params.lse_accum + args.n_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + idx_in_warpgroup;
*gSoftmaxLseAccum = cur_lse;
}
}
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase);
rs.update();
ku::tcgen05_after_thread_sync();
if (args.is_last_batch) {
cudaTriggerProgrammaticLaunchCompletion();
}
if (args.is_no_split) {
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, args.batch_idx),
Shape<Int<B_H>, Int<64>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor tma_sO = flat_divide(
sO,
Shape<Int<B_H>, Int<64>>{}
)(_, _, _0{}, _);
float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li + exp2f(attn_sink - mi));
float2 o_scale_float2 = {o_scale, o_scale};
float2 o[B_EPI/2];
__nv_bfloat162 o_bf16[B_EPI/2];
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j) {
o[j] = ku::float2_mul(o[j], o_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[j]);
}
// Store
int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 8; ++j)
*(__int128_t*)(sO_bases[j] + col_base*B_H) = *(__int128_t*)(&o_bf16[j*4]);
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
// S -> G
if (warp_idx == 0 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(tma_sO(_, _, col_base/64)),
thr_tma.partition_D(tma_gO(_, _, col_base/64))
);
}
if (warp_idx == 1 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(tma_sO(_, _, col_base/64 + (D_V/4)/64)),
thr_tma.partition_D(tma_gO(_, _, col_base/64 + (D_V/4)/64))
);
}
}
cute::tma_store_arrive();
} else {
float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li); // Here we leave attn_sink to the combine kernel, otherwise attn_sink will take effect for multiple times
float2 o_scale_float2 = {o_scale, o_scale};
constexpr int B_EPI = 64;
float2 o[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_accum_buf.data()), SmemLayoutOAccumBuf{});
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j)
o[j] = ku::float2_mul(o[j], o_scale_float2);
// Store
int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 4; ++j)
*(__int128_t*)&sO(idx_in_warpgroup%64, col_base + j*4) = *(__int128_t*)(&o[j*2]);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
if (elect_one_sync()) {
CUTE_UNROLL
for (int local_row = 0; local_row < B_H/4; ++local_row) {
int smem_row = local_row*4 + warp_idx;
SM90_BULK_COPY_S2G::copy(
&sO(smem_row, _0{}),
(float*)params.o_accum + args.n_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + smem_row*params.stride_o_accum_h_q,
D_V*sizeof(float)
);
}
cute::tma_store_arrive();
}
}
});
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_dealloc<72>();
const int warp_idx = cutlass::canonical_warp_idx_sync(); // Missing this leads to reg spilling
if (warp_idx == 4 && elect_one_sync()) {
// MMA Warp
run_main_loop([&](const MainLoopArgs &args) {
if (args.start_block_idx >= args.end_block_idx) {
ku::trap();
}
// Issue Q (SW128) G->S
{
Tensor gQ = tma_params.tma_Q_SW128.get_tma_tensor(tma_params.shape_Q_SW128)(_, _, s_q_idx, args.batch_idx);
Tensor sQ = make_tensor(make_smem_ptr(plan.u.qo.q.data()), SmemLayoutQ_SW128{});
ku::launch_tma_copy(
tma_params.tma_Q_SW128,
gQ,
sQ,
plan.bar_q_tma,
TMA::CacheHintSm90::EVICT_FIRST
);
}
// Issue Q (SW64) G -> S
if constexpr (D_Q_SW64 > 0) {
cute::SM90_TMA_LOAD_5D::copy(
&tma_params.tensor_map_q_sw64,
(uint64_t*)&plan.bar_q_tma,
(uint64_t)TMA::CacheHintSm90::EVICT_FIRST,
plan.u.qo.q_sw64,
0, 0, 0,
s_q_idx, args.batch_idx
);
}
plan.bar_q_tma.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
plan.bar_q_tma.wait(args.bar_phase_batch_rel);
ku::tcgen05_after_thread_sync();
// Issue Q (SW128) UTCCP
{
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.qo.q.data()),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64>>{} // *2 to leverage dual GEMM
)
)
);
static_assert(D_Q_SW128%128 == 0);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q_SW128/128; ++tile_idx) {
// Each tile: 64 x (64*2) logically, 128 x 64 bf16 on TMEM
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 64/16; ++subtile_idx) {
// Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_desc + (tile_idx*(B_H*128) + subtile_idx*16) * 2 / 16,
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
}
// Issue Q (SW64) UTCCP
if constexpr (D_Q_SW64 > 0) {
UMMA::SmemDescriptor sQ_SW64_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.qo.q_sw64),
tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32>>{} // *2 to leverage dual GEMM
)
)
);
static_assert(D_Q_SW64%64 == 0);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q_SW64/64; ++tile_idx) {
// Each tile: 64 x (32*2) logically, 128 x 32 bf16 on TMEM
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 32/16; ++subtile_idx) {
// Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_SW64_desc + (tile_idx*(B_H*64) + subtile_idx*16) * 2 / 16,
tmem_cols::Q + (B_H*D_Q_SW128/2/128) + tile_idx*16 + subtile_idx*8
);
}
}
}
ku::umma_arrive_noelect(plan.bar_q_utccp);
// Allocate tmem tensors
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
// NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});
tP.data().get() = tmem_cols::P;
tO.data().get() = tmem_cols::O;
// Wait for UTCCP
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
ku::tcgen05_after_thread_sync();
// Mainloop
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
if constexpr (MODEL_TYPE == ModelType::V32) {
// V3.2: RoPE behaves like an extra block with size 64, so we can do RoPE first
// QK RoPE
plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_ROPE/2>>{})
);
tQ_rope.data().get() = tmem_cols::Q_Tail;
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].rope.data()), SmemLayoutKTiles_DualGemm_SW64<2/2>{});
ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);
// QK NoPE
plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ_nope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_NOPE/2>>{})
);
tQ_nope.data().get() = tmem_cols::Q;
Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});
ku::utcmma_ts(tiled_mma_P, tQ_nope, sK_nope, tP, false);
} else {
// MODEL1: RoPE is the last 64 dims within the full 512 dim, which couples with the last 64 dim from the NoPE part when performing dual GEMM. i.e.
//
// logical view: |0|1|2|3|4|5|6|7| (where 7 is the RoPE part)
// dual gemm's view:
// |0|2|4|6|
// |1|3|5|7|
//
// So we must wait for both the NoPE and the RoPE part, and then perform dual GEMM
plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);
plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_Q/2>>{})
);
tQ.data().get() = tmem_cols::Q;
Tensor sK = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});
ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);
}
ku::umma_arrive_noelect(plan.bar_qk_done[rs.buf_idx]);
// SV
plan.bar_so_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor sS = make_tensor(make_smem_ptr(plan.s_p.s.data()), SmemLayoutS{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTilesTransposed_SW128<D_V/64>{}); // NOTE: For MODEL1, it "expands" to the RoPE part.
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, block_idx == args.start_block_idx);
ku::umma_arrive_noelect(plan.bar_sv_done[rs.buf_idx]);
rs.update();
}
});
} else if (warp_idx == 5 && elect_one_sync()) {
// Raw KV NoPE retrieval warp
run_main_loop([&](const MainLoopArgs &args) {
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
plan.bar_last_store_done.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
plan.bar_raw_free[rs.buf_idx].wait(rs.bar_phase^1);
int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);
int4 nxt_cur_indices;
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
if (row+4 < B_TOPK)
nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,
plan.bar_raw_ready[rs.buf_idx],
plan.u.kv.raw_nope[rs.buf_idx].data() + D_NOPE*row,
0,
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
cur_indices = nxt_cur_indices;
}
plan.bar_raw_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_NOPE*sizeof(e4m3));
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
} else if (warp_idx == 6 && elect_one_sync()) {
// KV RoPE retrieval warp
run_main_loop([&](const MainLoopArgs &args) {
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
plan.bar_last_store_done.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
if constexpr (MODEL_TYPE == ModelType::V32) {
plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase^1);
} else {
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);
}
int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);
int4 nxt_cur_indices;
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
if (row+4 < B_TOPK)
nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);
CUTE_UNROLL
for (int t = 0; t < D_ROPE/(K_ROPE_SW/2); ++t) {
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,
plan.bar_rope_ready[rs.buf_idx],
plan.u.kv.dequant[rs.buf_idx].rope.data() + (K_ROPE_SW/2)*row + t*B_TOPK*(K_ROPE_SW/2),
t*(K_ROPE_SW/2),
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
cur_indices = nxt_cur_indices;
}
plan.bar_rope_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
} else if (warp_idx == 7) {
// Indices transformation warp
// Responsible for generating: TMA coordinates, scale factors, and valid masks
static_assert(B_TOPK == 64);
static constexpr int tma_coords_step_per_token = MODEL_TYPE == ModelType::V32 ? 656/TMA_K_STRIDE : 576/TMA_K_STRIDE;
int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE > 512
int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE;
uint8_t* k_scales_ptr =
MODEL_TYPE == ModelType::V32 ?
(uint8_t*)params.kv + D_NOPE :
(uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);
uint8_t* extra_k_scales_ptr =
MODEL_TYPE == ModelType::V32 ?
(uint8_t*)params.extra_kv + D_NOPE :
(uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);
run_main_loop([&](const MainLoopArgs &args) {
int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*s_q_idx;
int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*s_q_idx;
struct IsOrigBlock {};
struct IsExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t) {
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;
int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;
[[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;
uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;
int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block;
int abs_pos, my_indices[2];
if (!IS_EXTRA_BLOCK) {
abs_pos = block_idx*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(indices + abs_pos));
} else {
abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));
}
plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1);
int tma_coords[2];
e8m0 scales[2*NUM_SCALES_EACH_TOKEN];
char valid_mask = 0;
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int block_idx, idx_in_block;
block_idx = (unsigned int)my_indices[i] / cur_block_size;
idx_in_block = (unsigned int)my_indices[i] % cur_block_size;
bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));
valid_mask |= is_token_valid << i;
tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN.
if constexpr (MODEL_TYPE == ModelType::V32) {
int64_t offset = is_token_valid ? block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0;
float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset));
e8m0 res[4];
*(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero);
*(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero);
if (!is_token_valid) *(uint32_t*)res = (uint32_t)0;
*(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res);
} else {
int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding
uint64_t scalesx8 = is_token_valid ? __ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0;
*(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = scalesx8;
}
}
valid_mask <<= lane_idx%4*2;
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);
if constexpr (MODEL_TYPE == ModelType::V32) {
*(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;
} else {
*(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales;
}
*(int2*)(plan.tma_coord[rs.index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;
if (lane_idx%4 == 0)
plan.is_token_valid[rs.index_buf_idx][lane_idx/4] = valid_mask;
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].arrive();
rs.update();
};
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{});
}
});
} else {
run_main_loop([&](const MainLoopArgs &args) {});
}
} else {
// Dequant warpgroup
cutlass::arch::warpgroup_reg_alloc<208>();
// 8 threads per token
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = D_NOPE/(GROUP_SIZE*8);
int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;
Tensor nope0 = make_tensor(make_smem_ptr(plan.u.kv.dequant[0].nope.data()), SmemLayoutKTiles_SW128<D_NOPE/64>{});
bf16* nope0_base = &nope0(group_idx, idx_in_group*8);
bf16* nope1_base = nope0_base + (plan.u.kv.dequant[1].nope.data() - plan.u.kv.dequant[0].nope.data());
e4m3* raw_nope0_base = plan.u.kv.raw_nope[rs.buf_idx].data() + group_idx*D_NOPE + idx_in_group*8;
e4m3* raw_nope1_base = raw_nope0_base + B_H*D_NOPE;
run_main_loop([&](const MainLoopArgs &args) {
// plan.bar_last_store_done.wait(args.bar_phase_batch_rel); // No need to wait since the raw nope producer must wait
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
plan.bar_raw_ready[rs.buf_idx].wait(rs.bar_phase);
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);
uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(rs.buf_idx == 0 ? nope0_base : nope1_base);
e4m3* raw_nope_base = rs.buf_idx == 0 ? raw_nope0_base : raw_nope1_base;
auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {
asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n"
:
: "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16)
); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS
};
auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {
return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*D_NOPE + local_col_idx*(GROUP_SIZE*8));
};
// The following code suffers from a 2-way bank conflict when reading from SMEM.
if constexpr (MODEL_TYPE == ModelType::V32) {
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[4];
e8m0 scales_e8m0[4];
*(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx];
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx / (D_NOPE/(GROUP_SIZE*8)/4)];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
} else {
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[8];
e8m0 scales_e8m0[8];
*(uint64_t*)scales_e8m0 = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx];
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
*(__nv_bfloat162_raw*)(scales+4) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+4));
*(__nv_bfloat162_raw*)(scales+6) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+6));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
}
cutlass::arch::fence_view_async_shared();
plan.bar_nope_ready[rs.buf_idx].arrive();
plan.bar_raw_free[rs.buf_idx].arrive();
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119");
}
#endif
}
template<typename Kernel, typename TmaParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {
Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(params, tma_params);
}
template<ModelType MODEL_TYPE>
void KernelTemplate<MODEL_TYPE>::run(const SparseAttnDecodeParams &params) {
KU_ASSERT(params.topk % B_TOPK == 0, "topk (%d) mod B_TOPK (%d) must be 0", params.topk, B_TOPK);
KU_ASSERT(params.extra_topk % B_TOPK == 0, "extra_topk (%d) mod B_TOPK (%d) must be 0", params.extra_topk, B_TOPK);
KU_ASSERT(params.h_q == B_H);
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.d_qk == D_Q);
KU_ASSERT(params.d_v == D_V);
if constexpr (MODEL_TYPE == ModelType::MODEL1) {
constexpr int BYTES_PER_TOKEN = D_NOPE + 2*D_ROPE + 8;
KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous
}
auto shape_Q_SW128 = make_shape(B_H, D_Q, params.s_q, params.b);
auto tma_Q_SW128 = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q_SW128,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)
)
),
SmemLayoutQ_SW128{}
);
auto shape_O = make_shape(B_H, D_V, params.s_q, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.out),
make_layout(
shape_O,
make_stride(params.stride_o_h_q, _1{}, params.stride_o_s_q, params.stride_o_b)
)
),
SmemLayoutOBuf_TMA{}
);
CUtensorMap tensor_map_q_sw64{};
if constexpr (D_Q_SW64 > 0) {
tensor_map_q_sw64 = ku::make_tensor_map(
{D_Q_SW64, (uint64_t)params.h_q, D_Q_SW64/32, (uint64_t)params.s_q, (uint64_t)params.b},
ku::make_stride_helper(std::vector<int64_t>{params.stride_q_h_q, (int64_t)32, params.stride_q_s_q, params.stride_q_b}, sizeof(bf16)),
{32, B_H, D_Q_SW64/32, 1, 1},
(bf16*)params.q + D_Q_SW128,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
);
}
auto get_nope_rope_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t k_batch_stride) -> std::pair<CUtensorMap, CUtensorMap> {
static_assert(D_NOPE%8 == 0);
KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr);
KU_ASSERT(k_batch_stride % TMA_K_STRIDE == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", k_batch_stride, TMA_K_STRIDE);
CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(
{D_NOPE/8, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
{TMA_K_STRIDE},
{D_NOPE/8, 1},
k_ptr,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT64,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
); // NOTE We combine 8 float8 into 1 int64 since boxdim cannot > 256
CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(
{D_ROPE, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
{TMA_K_STRIDE},
{K_ROPE_SW/2, 1},
(uint8_t*)k_ptr + (MODEL_TYPE == ModelType::V32 ? (D_NOPE+16) : D_NOPE),
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
K_ROPE_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
);
return {tensor_map_kv_nope, tensor_map_kv_rope};
};
auto [tensor_map_kv_nope, tensor_map_kv_rope] = get_nope_rope_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block);
CUtensorMap tensor_map_extra_kv_nope{}, tensor_map_extra_kv_rope{};
if (params.extra_topk > 0) {
std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_nope_rope_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block);
}
TmaParams<
decltype(shape_Q_SW128), decltype(tma_Q_SW128),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q_SW128, tma_Q_SW128,
shape_O, tma_O,
tensor_map_q_sw64,
tensor_map_kv_nope,
tensor_map_kv_rope,
tensor_map_extra_kv_nope,
tensor_map_extra_kv_rope
};
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE>, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
static_assert(smem_size < 227*1024);
KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// NOTE Don't use PDL because of potential compiler bugs!
mla_kernel<<<dim3(params.s_q, params.num_sm_parts, 1), dim3(NUM_THREADS, 1, 1), smem_size, params.stream>>>(params, tma_params);
KU_CHECK_KERNEL_LAUNCH();
}
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {
KernelTemplate<MODEL_TYPE>::run(params);
}
}
#pragma once
#include "params.h"
namespace sm100::decode::head64 {
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
}
#pragma once
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include "sm100/defines.h"
namespace sm100 {
struct fp8x8 {
__nv_fp8x4_e4m3 lo;
__nv_fp8x4_e4m3 hi;
};
struct fp8x32 {
fp8x8 a0, a1, a2, a3;
};
struct fp8x16 {
fp8x8 a0, a1;
};
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
__nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);
#define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
{ \
float4 fp32x4 = (float4)(FP8x4); \
OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
}
bf16x8 result;
DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);
return result;
}
__device__ __forceinline__
fp8x32 ldg_256_fp8x32(void* src_ptr) {
int32x8_t val;
asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
: "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3),
"=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7)
: "l"(src_ptr)
);
return *reinterpret_cast<fp8x32*>(&val);
}
__device__ __forceinline__
fp8x16 ldg_128_fp8x16(void* src_ptr) {
int4 ret;
asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
: "l"(src_ptr));
return *reinterpret_cast<fp8x16*>(&ret);
}
}
#include "splitkv_mla.h"
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include "utils.h"
#include "dequant.h"
#include "sm100/defines.h"
#include "sm100/helpers.h"
#include "sm100/intrinsics.h"
#include "sm100/ws_gemm.h"
namespace sm100 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using namespace cute;
constexpr int B_H = 64;
constexpr int B_TOPK = 64;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr int NUM_BUFS = 2;
constexpr int NUM_THREADS = 128*3;
constexpr int NUM_WORKING_THREADS = 128 + 128 + 32;
constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
Shape_O shape_O; TMA_O tma_O;
};
namespace tmem_addr {
constexpr int o = 0; // o: [0, 256]
constexpr int p = 256; // p: [256, 288]
};
using SmemLayoutQ = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_K>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{}, // TODO This may lead to TMA double traffic
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
struct SharedMemoryPlan {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_BUFS];
} u;
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
transac_bar_t bar_q;
transac_bar_t bar_k_ready[NUM_BUFS], bar_k_free[NUM_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS];
float rowwise_max_buf[128], rowwise_li_buf[128];
bool is_token_valid[NUM_BUFS][B_TOPK];
array_aligned<uint32_t, 1> tmem_start_addr;
};
using TiledMMA_QK = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
)); // TODO Use TS?
using TiledMMA_SV = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<B_H>, Int<D_V>>{}
));
template<typename T>
CUTE_DEVICE
void store_128b(void* smem_ptr, const T &data) {
static_assert(sizeof(T) == 16);
*(__int128*)smem_ptr = *(__int128*)&data;
}
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM100
const int head_block_idx = blockIdx.x;
const int s_q_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{});
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
}
if (warp_idx == 0) {
if (elect_one_sync()) {
plan.bar_q.init(1);
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_k_ready[i].init(128);
plan.bar_k_free[i].init(1);
plan.bar_qk_done[i].init(1);
plan.bar_so_ready[i].init(128);
}
cutlass::arch::fence_barrier_init();
}
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
int bar_phase_k = 0;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int sched_end_block_idx = tile_scheduler_metadata.w;
if (begin_idx >= params.b) {
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
return;
}
auto get_cur_req_info = [&](int batch_idx) -> std::tuple<int, int, bool> {
int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : params.topk / B_TOPK;
bool is_no_split = start_block_idx == 0 && end_block_idx == params.topk / B_TOPK;
return {start_block_idx, end_block_idx, is_no_split};
};
if (warpgroup_idx == 0) {
// Producer warpgroup
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1)
constexpr int GROUP_SIZE = 4, NUM_GROUPS = 128 / GROUP_SIZE;
constexpr int ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for buffer to be available
plan.bar_k_free[buf_idx].wait(bar_phase_k>>buf_idx&1^1);
// Load
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_GROUP; ++local_row) {
int smem_row = group_idx + local_row*NUM_GROUPS;
int token_index = __ldg(gIndices + block_idx*B_TOPK + smem_row);
bool is_token_invalid = token_index == -1;
if (idx_in_group == 0)
plan.is_token_valid[buf_idx][smem_row] = !is_token_invalid;
if (is_token_invalid) {
uint128_t zeros = uint128_t{};
CUTE_UNROLL
for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) {
int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16;
store_128b(&sK(smem_row, col_base ), zeros);
store_128b(&sK(smem_row, col_base+8), zeros);
}
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) {
int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8;
store_128b(&sK(smem_row, D_V+col_base), zeros);
}
} else {
int block_index = token_index/B_TOPK;
int rel_idx_in_block = (token_index+B_TOPK) % B_TOPK; // NOTE When token_index is -1, -1/B_TOPK = 0 and (-1+B_TOPK)%B_TOPK = 63, so there will be no illegal-memory-access error. However, masking is necessary to prevent NaN (TODO Skip some rows instead?) TODO Masking
fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride;
float4 scales = __ldg((float4*)(gK_base + D_V));
CUTE_UNROLL
for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) {
int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16;
fp8x16 cur_fp8s = ldg_128_fp8x16(gK_base + col_base);
float cur_scale = local_col < (256/(GROUP_SIZE*16)) ?
(local_col < (128/(GROUP_SIZE*16)) ? scales.x : scales.y) :
(local_col < (384/(GROUP_SIZE*16)) ? scales.z : scales.w);
store_128b(&sK(smem_row, col_base ), cvt_fp8x8_bf16x8(cur_fp8s.a0, cur_scale));
store_128b(&sK(smem_row, col_base+8), cvt_fp8x8_bf16x8(cur_fp8s.a1, cur_scale));
}
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) {
int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8;
fp8x16 cur_k_rope_fp8s = ldg_128_fp8x16(gK_base + D_V + 4*sizeof(float) + col_base*sizeof(bf16));
bf16x8 cur_k_rope = *reinterpret_cast<bf16x8*>(&cur_k_rope_fp8s);
store_128b(&sK(smem_row, D_V+col_base), cur_k_rope);
}
}
}
fence_view_async_shared();
// Signal
plan.bar_k_ready[buf_idx].arrive();
bar_phase_k ^= 1<<buf_idx;
}
}
} else if (warpgroup_idx == 1) {
// Scale & Exp warpgroup
cutlass::arch::warpgroup_reg_alloc<240>();
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
float li = 0.0f;
float mi = MAX_INIT_VAL;
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for P
plan.bar_qk_done[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
// Load P from TMEM
float p[B_TOPK/2];
float2* p_float2 = reinterpret_cast<float2*>(p);
tmem_ld_32dp32bNx<B_TOPK/2>(tmem_addr::p, p);
cutlass::arch::fence_view_async_tmem_load();
// Get rowwise max
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2; ++i) {
if (!plan.is_token_valid[buf_idx][(idx_in_warpgroup/64)*(B_TOPK/2)+i]) p[i] = -INFINITY;
cur_max = max(cur_max, p[i]);
}
cur_max *= params.scale_softmax_log2;
NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers
plan.rowwise_max_buf[idx_in_warpgroup] = cur_max;
NamedBarrier::arrive_and_wait(128, 0);
cur_max = max(cur_max, plan.rowwise_max_buf[idx_in_warpgroup ^ 64]);
float new_max = max(mi, cur_max);
float scale_for_old = exp2f(mi - new_max);
float2 scale_for_old_float2 = {scale_for_old, scale_for_old};
// Get S
float2 scale_softmax_log2_float2 = {params.scale_softmax_log2, params.scale_softmax_log2};
float2 neg_new_max_float2 = {-new_max, -new_max};
bf16 s[B_TOPK/2];
float2 cur_sum = {0.0f, 0.0f};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; ++i) {
float2 t = float2_fma(p_float2[i], scale_softmax_log2_float2, neg_new_max_float2);
t.x = exp2(t.x);
t.y = exp2(t.y);
*(__nv_bfloat162*)&s[i*2] = __float22bfloat162_rn(t);
cur_sum = float2_add(cur_sum, t);
}
// Save S
// NOTE We don't need a barrier here, since the current QK^T has finished implies that the previous SV has finished
bf16* sS_base = plan.s.data() + (idx_in_warpgroup/64)*(B_H*B_TOPK/2) + (idx_in_warpgroup%64) * 8;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/8; i += 1) {
store_128b(sS_base + i*8*B_H, *((bf16x8*)s + i));
}
fence_view_async_shared();
// Rescale O
if (block_idx != start_block_idx) {
constexpr int B_SCALE_O = 64;
float2 o[B_SCALE_O/2];
CUTE_UNROLL
for (int b = 0; b < (D_V/2)/B_SCALE_O; ++b) {
tmem_ld_32dp32bNx<B_SCALE_O>(tmem_addr::o + b*B_SCALE_O, o);
cutlass::arch::fence_view_async_tmem_load();
CUTE_UNROLL
for (int i = 0; i < B_SCALE_O/2; ++i)
o[i] = float2_mul(o[i], scale_for_old_float2);
tmem_st_32dp32bNx<B_SCALE_O>(tmem_addr::o + b*B_SCALE_O, o);
cutlass::arch::fence_view_async_tmem_store();
}
}
plan.bar_so_ready[buf_idx].arrive();
// Update mi and li
mi = new_max;
li = li * scale_for_old + cur_sum.x + cur_sum.y;
bar_phase_k ^= 1<<buf_idx;
}
// Epilogue
// Deal with no valid token cases
if (mi == MAX_INIT_VAL) {
mi = -INFINITY;
li = 0.0f;
}
// Reduce li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, 0);
li += plan.rowwise_li_buf[idx_in_warpgroup ^ 64];
// Save li
int num_valid_heads = min(B_H, params.q_head_per_hk - head_block_idx*B_H);
int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*B_H;
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int split_idx = is_no_split ? 0 : (__ldg(params.num_splits_ptr+batch_idx) + n_split_idx);
if (idx_in_warpgroup < num_valid_heads) {
if (is_no_split) {
float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx + idx_in_warpgroup;
*gSoftmaxLse = li == 0.0f ? INFINITY : logf(li) + mi / (float)M_LOG2E; // NOTE Follows Flash MLA's approach, which returns +inf when there are no valid indices
} else {
float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx + idx_in_warpgroup;
*gSoftmaxLseAccum = li == 0.0f ? -INFINITY : log2f(li) + mi;
}
}
// Wait for the last SV gemm
plan.bar_k_free[(end_block_idx-1)%NUM_BUFS].wait(bar_phase_k>>((end_block_idx-1)%NUM_BUFS)&1^1);
tcgen05_after_thread_sync();
// Save O
float o_scale = li == 0.0f ? 0.0f : 1.0f / li;
float2 o_scale_float2 = {o_scale, o_scale};
if (is_no_split) {
constexpr int B_EPI = 32;
float2 o[B_EPI/2];
__nv_bfloat162 o_bf16[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.o_buf.data()), SmemLayoutOBuf{});
bf16* sO_base = plan.u.o_buf.data() + ((idx_in_warpgroup/64)*128)*B_H + (idx_in_warpgroup%64)*8;
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
tmem_ld_32dp32bNx<B_EPI>(tmem_addr::o + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j) {
o[j] = float2_mul(o[j], o_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[j]);
}
// Store
int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 8; ++j)
store_128b(sO_base + (col_base+j*8)*B_H, *reinterpret_cast<bf16x8*>(&o_bf16[j*4]));
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
if (warp_idx == 4 && elect_one_sync()) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<B_H>, Int<D_V>>{})(_, _, head_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
constexpr int B_EPI = 64;
float2 o[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.o_accum_buf.data()), SmemLayoutOAccumBuf{});
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
tmem_ld_32dp32bNx<B_EPI>(tmem_addr::o + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j)
o[j] = float2_mul(o[j], o_scale_float2);
// Store
int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 4; ++j)
store_128b(&sO(idx_in_warpgroup%64, col_base + j*4), *reinterpret_cast<float4*>(&o[j*2]));
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
if (elect_one_sync()) {
CUTE_UNROLL
for (int local_row = 0; local_row < B_H/4; ++local_row) {
int smem_row = local_row*4 + (warp_idx-4);
if (smem_row < num_valid_heads) {
SM90_BULK_COPY_S2G::copy(
&sO(smem_row, _0{}),
(float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx + smem_row)*D_V,
D_V*sizeof(float)
);
}
}
cute::tma_store_arrive();
}
}
cute::tma_store_wait<0>();
}
if (warp_idx == 4) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else {
cutlass::arch::warpgroup_reg_dealloc<96>();
if (warp_idx == 8) {
// UTCMMA warp
bool bar_phase_q = 0;
TiledMMA tiled_mma_qk = TiledMMA_QK{};
TiledMMA tiled_mma_sv = TiledMMA_SV{};
Tensor tP = partition_fragment_C(tiled_mma_qk, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor tO = partition_fragment_C(tiled_mma_sv, Shape<Int<B_H>, Int<D_V>>{});
tO.data().get() = tmem_addr::o;
tP.data().get() = tmem_addr::p;
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
if (elect_one_sync()) {
// Copy Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx),
Tile<Int<B_H>, Int<D_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
}
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
if (elect_one_sync()) {
// Wait for Q
plan.bar_q.wait(bar_phase_q);
bar_phase_q ^= 1;
tcgen05_after_thread_sync();
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for K
plan.bar_k_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
// Issue P = Q @ K^T
utcmma_ss(tiled_mma_qk, sQ, sK, tP, true);
umma_arrive_noelect(plan.bar_qk_done[buf_idx]);
// Wait for S
plan.bar_so_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutV{});
// Issue O += S @ V
utcmma_ss(tiled_mma_sv, sS, sV, tO, block_idx == start_block_idx);
umma_arrive_noelect(plan.bar_k_free[buf_idx]);
bar_phase_k ^= 1<<buf_idx;
}
}
__syncwarp();
// NOTE If we reach this point, we must have done the QK gemm (since we've waited for bar_so_ready)
// So we can launch the copy of the next Q block immediately
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119");
}
#endif
}
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream) {
FLASH_ASSERT(params.h_k == 1);
FLASH_ASSERT(params.topk % B_TOPK == 0);
auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q_ptr),
make_layout(
shape_Q,
make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride)
)
),
SmemLayoutQ{}
);
auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.o_ptr),
make_layout(
shape_O,
make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride)
)
),
SmemLayoutOBuf{}
);
TmaParams<
decltype(shape_Q), decltype(tma_Q),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q, tma_Q,
shape_O, tma_O
};
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
const int num_m_blocks = cute::ceil_div(params.q_head_per_hk, B_H);
// NOTE Don't use PDL because of potential compiler bugs!
mla_kernel<<<dim3(num_m_blocks, params.s_q, params.num_sm_parts), dim3(NUM_THREADS, 1, 1), smem_size, stream>>>(params, tma_params);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
\ No newline at end of file
#pragma once
#include "params.h"
namespace sm100 {
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once #pragma once
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include "defines.h" #include "defines.h"
namespace sm100 { namespace sm100 {
using namespace cute; using namespace cute;
using _72 = Int<72>;
using _576 = Int<576>;
template<
typename TMA,
typename Tensor0,
typename Tensor1
>
CUTE_DEVICE CUTE_DEVICE
void launch_tma_copy( int int4_max(int4 t) {
const TMA &tma_copy, return max(max(t.x, t.y), max(t.z, t.w));
Tensor0 src,
Tensor1 dst,
transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
auto thr_tma = tma_copy.get_slice(_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
} }
template<
typename TiledMMA,
typename TensorA,
typename TensorB,
typename TensorFragC
>
CUTE_DEVICE CUTE_DEVICE
void utcmma_ss( int int4_min(int4 t) {
TiledMMA &tiled_mma, return min(min(t.x, t.y), min(t.z, t.w));
TensorA sA,
TensorB sB,
TensorFragC tC_frag,
bool clear_accum
) {
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sA_frag = thr_mma.partition_fragment_A(sA);
auto sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag));
static_assert(size<1>(sA_frag) == size<1>(tC_frag));
static_assert(size<1>(sB_frag) == size<2>(tC_frag));
CUTE_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(
tiled_mma,
sA_frag(_, _, k),
sB_frag(_, _, k),
tC_frag
);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
} }
template< // Convert 2x fp8_e4m3 to 2x bf16 with scaling
typename TiledMMA,
typename TensorA,
typename TensorB,
typename TensorFragC
>
CUTE_DEVICE CUTE_DEVICE
void utcmma_ts( nv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) {
TiledMMA &tiled_mma, // TODO Use native conversion for CUDA >= 13.1
TensorA tA_frag, float2 data_float2 = (float2)data;
TensorB sB, nv_bfloat162 data_bf16x2 = __float22bfloat162_rn(data_float2);
TensorFragC tC_frag, return nv_bfloat162 {
bool clear_accum data_bf16x2.x * scale,
) { data_bf16x2.y * scale
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; };
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(tA_frag) == size<2>(sB_frag));
CUTE_UNROLL
for (int k = 0; k < size<2>(tA_frag); ++k) {
cute::gemm(
tiled_mma,
tA_frag(_, _, k),
sB_frag(_, _, k),
tC_frag
);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
} }
} }
#pragma once
#include <cute/tensor.hpp>
#include <cute/arch/simd_sm100.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr),
"l"(src),
"n"(16));
}
CUTE_DEVICE
int64_t createpolicy_evict_last() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
template<typename T>
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b.");
long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
:
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
);
}
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
CUTE_DEVICE
void umma_arrive_noelect(transac_bar_t &smem_ptr) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
:
:"r"(bar_intptr));
}
CUTE_DEVICE
void umma_arrive_2x1SM_noelect(transac_bar_t &smem_ptr) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];"
:
:"r"(bar_intptr));
}
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
float2 res;
cute::add(res, a, b);
return res;
}
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
float2 res;
cute::mul(res, a, b);
return res;
}
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
// return a*b+c
float2 res;
cute::fma(res, a, b, c);
return res;
}
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
float2 t = {-1.0f, -1.0f};
return float2_mul(a, t);
}
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(uint64_t(cache_hint))
: "memory"
);
}
// 32 data path lanes, 32-bit pattern, repeated N times
template <int N, typename T>
CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32"
"{%0},"
"[%1];\n"
: "=r"(dst_ptr[0])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32"
"{%0, %1},"
"[%2];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x64.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x128.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
asm volatile ("trap");
}
}
// 32 data path lanes, 32-bit pattern, repeated N times
template <int N, typename T>
CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
uint32_t* src_ptr = reinterpret_cast<uint32_t*>(src_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32"
"[%1], {%0};\n"
:
: "r"(src_ptr[0]),
"r"(dst_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32"
"[%2], {%0, %1};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]),
"r"(dst_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32"
"[%4], {%0, %1, %2, %3};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]),
"r"(dst_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32"
"[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]),
"r"(dst_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32"
"[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]),
"r"(dst_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32"
"[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]),
"r"(dst_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x64.b32"
"[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]),
"r"(dst_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x128.b32"
"[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]),
"r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]),
"r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]),
"r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]),
"r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]),
"r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]),
"r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]),
"r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]),
"r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]),
"r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]),
"r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]),
"r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]),
"r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]),
"r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]),
"r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]),
"r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]),
"r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]),
"r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]),
"r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]),
"r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]),
"r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]),
"r"(src_ptr[126]), "r"(src_ptr[127]),
"r"(dst_addr));
} else {
asm volatile ("trap");
}
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
}
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "utils.h" // for IS_SM100 #include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
namespace cutlass::fmha::kernel { namespace cutlass::fmha::kernel {
...@@ -139,7 +139,7 @@ struct FmhaKernelBwdConvert { ...@@ -139,7 +139,7 @@ struct FmhaKernelBwdConvert {
} }
CUTLASS_DEVICE void operator()(const Params &params, char* smem) { CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100 #if defined(KERUTILS_ENABLE_SM100A)
if (params.ptr_src_dQ != nullptr) { if (params.ptr_src_dQ != nullptr) {
copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape));
} }
......
...@@ -34,7 +34,7 @@ ...@@ -34,7 +34,7 @@
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "utils.h" // for IS_SM100 #include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
namespace cutlass::fmha::kernel { namespace cutlass::fmha::kernel {
...@@ -105,7 +105,7 @@ struct FmhaKernelBwdSumOdO { ...@@ -105,7 +105,7 @@ struct FmhaKernelBwdSumOdO {
} }
CUTLASS_DEVICE void operator()(const Params &params, char* smem) { CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100 #if defined(KERUTILS_ENABLE_SM100A)
auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O);
auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO);
auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO);
......
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
#include "cutlass/arch/memory_sm80.h" #include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.h" // for IS_SM100 #include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../collective/fmha_common.hpp" #include "../collective/fmha_common.hpp"
#include <cmath> #include <cmath>
...@@ -949,8 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ...@@ -949,8 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord, TensorC const& coord,
TensorShape const& tensor_shape) { TensorShape const& tensor_shape) {
// TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version. Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy( auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{}, Copy_Atom<UniversalCopy<uint128_t>, Element>{},
...@@ -960,23 +959,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ...@@ -960,23 +959,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
auto thr_copy = copy_op.get_slice(_0{}); auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs); Tensor quantized_regs = quantize(regs);
auto tCg = thr_copy.partition_D(gmem); Tensor tCr = thr_copy.partition_S(quantized_regs);
auto tCr = thr_copy.partition_S(quantize(regs)); Tensor tCg = thr_copy.partition_D(gmem);
auto tCc = thr_copy.partition_D(coord); Tensor tPc = thr_copy.partition_D(preds);
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
copy_if(copy_op, tPc, tCr, tCg);
} }
...@@ -1500,7 +1487,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ...@@ -1500,7 +1487,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
CUTLASS_DEVICE void operator()(Params const& params, char* smem) { CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if IS_SM100 #if defined(KERUTILS_ENABLE_SM100A)
int warp_idx = cutlass::canonical_warp_idx_sync(); int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx); auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync(); uint32_t lane_predicate = cute::elect_one_sync();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment