Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
......@@ -9,3 +9,4 @@ dist/
compile_commands.json
.cache
/dev
/.clangd
......@@ -28,7 +28,8 @@ FlashMLA is DeepSeek's library of optimized attention kernels, powering the [Dee
#### Test & benchmark MLA decoding (Sparse & Dense):
```bash
python tests/test_flash_mla_decoding.py
python tests/test_flash_mla_dense_decoding.py
python tests/test_flash_mla_sparse_decoding.py
```
The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configuration on H800 SXM5 with CUDA 12.8, and achieves up to 350 TFlops on B200 (which is not really optimized yet).
......@@ -44,7 +45,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation
#### Test & benchmark MLA prefill (Sparse):
```bash
python tests/test_flash_mla_prefill.py
python tests/test_flash_mla_sparse_prefill.py
```
It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9.
......
#include <pybind11/pybind11.h>
#include "sparse_fwd.h"
#include "sparse_decode.h"
#include "dense_decode.h"
#include "dense_fwd.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
}
#pragma once
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <kerutils/supplemental/torch_tensors.h>
#include <cutlass/bfloat16.h>
static constexpr float LOG_2_E = 1.44269504f;
// Instantiation for tensor.data_ptr<cutlass::bfloat16_t>()
template<>
inline cutlass::bfloat16_t* at::TensorBase::data_ptr<cutlass::bfloat16_t>() const {
return reinterpret_cast<cutlass::bfloat16_t*>(this->data_ptr());
}
// A struct that holds the architecture information of the current GPU.
struct Arch {
int major;
int minor;
int num_sms;
cudaDeviceProp* device_prop;
Arch() {
device_prop = at::cuda::getCurrentDeviceProperties();
major = device_prop->major;
minor = device_prop->minor;
num_sms = device_prop->multiProcessorCount;
}
bool is_sm90a() const {
return major == 9 && minor == 0;
}
bool is_sm100f() const {
return major == 10;
}
};
// Convert int64_t stride to int32_t, with overflow check.
inline int int64_stride_to_int(int64_t orig_stride) {
if (orig_stride > std::numeric_limits<int>::max()) {
TORCH_CHECK(false, "[FlashMLA] Stride exceeds int32 limit: ", orig_stride);
}
return static_cast<int>(orig_stride);
}
#define DISPATCH_NUM_HEADS(NUM_HEADS, CONSTEXPR_NAME, ...) \
[&] () { \
if (NUM_HEADS == 128) { \
static constexpr int CONSTEXPR_NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \
} ();
#define DISPATCH_HEAD_DIM(HEAD_DIM, CONSTEXPR_NAME, ...) \
[&] () { \
if (HEAD_DIM == 576) { \
static constexpr int CONSTEXPR_NAME = 576; \
return __VA_ARGS__(); \
} else if (HEAD_DIM == 512) { \
static constexpr int CONSTEXPR_NAME = 512; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported head_dim_qk: ", HEAD_DIM); \
} \
} ();
#define DISPATCH_BOOLEAN_FLAG(FLAG, CONSTEXPR_NAME, ...) \
[&] () { \
if (FLAG) { \
static constexpr bool CONSTEXPR_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONSTEXPR_NAME = false; \
return __VA_ARGS__(); \
} \
} ();
#define DISPATCH_MODEL_TYPE(MODEL_TYPE, CONSTEXPR_NAME, ...) \
[&] () { \
if (MODEL_TYPE == ModelType::V32) { \
static constexpr ModelType CONSTEXPR_NAME = ModelType::V32; \
return __VA_ARGS__(); \
} else if (MODEL_TYPE == ModelType::MODEL1) { \
static constexpr ModelType CONSTEXPR_NAME = ModelType::MODEL1; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported model type: ", (int)MODEL_TYPE); \
} \
} ();
// The following code is adapted from https://ykiko.me/en/articles/680412313/, which converts enum values to string names.
template<auto value>
constexpr auto get_static_enum_name(){
std::string_view name;
#if __GNUC__ || __clang__
name = __PRETTY_FUNCTION__;
std::size_t start = name.find('=') + 2;
std::size_t end = name.size() - 1;
name = std::string_view{ name.data() + start, end - start };
start = name.find("::");
#elif _MSC_VER
name = __FUNCSIG__;
std::size_t start = name.find('<') + 1;
std::size_t end = name.rfind(">(");
name = std::string_view{ name.data() + start, end - start };
start = name.rfind("::");
#endif
return start == std::string_view::npos ? name : std::string_view {
name.data() + start + 2, name.size() - start - 2
};
}
template<typename T, std::size_t N = 0>
static constexpr std::size_t get_enum_max(){
constexpr T value = static_cast<T>(N);
if constexpr (get_static_enum_name<value>().find(")") == std::string_view::npos)
return get_enum_max<T, N + 1>();
else
return N;
}
template<typename T> requires std::is_enum_v<T>
static constexpr std::string get_dynamic_enum_name(T value){
constexpr std::size_t num = get_enum_max<T>();
constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){
return std::array<std::string_view, num>{
get_static_enum_name<static_cast<T>(Is)>()...
};
}(std::make_index_sequence<num>{});
return (std::string)names[static_cast<std::size_t>(value)];
}
// A shortcut macro to declare supported features in an implementation class.
#define DECLARE_SUPPORTED_FEATURES(...) \
protected: \
static constexpr FeatureT features[] = { __VA_ARGS__ }; \
constexpr inline std::span<const FeatureT> get_supported_features() const override { \
return features; \
}
/*
ImplBase - The base class for every implementation.
Every implementation should inherit from this class and implement the pure virtual functions, including:
- `run_`: The function that runs the implementation.
- `get_supported_features`: The function that returns the supported features of the implementation. You may use `DECLARE_SUPPORTED_FEATURES` to declare the supported features in a concise way.
The dispatcher will invoke `ImplBase::run()`, which checks if all required features are supported by the implementation, and then calls `run_`.
*/
template<
typename RunArgT_,
typename FeatureT_
>
class ImplBase {
protected:
using RunArgT = RunArgT_;
using FeatureT = FeatureT_;
virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0;
constexpr virtual inline std::span<const FeatureT> get_supported_features() const = 0;
virtual ~ImplBase() = default;
public:
inline bool check_if_all_features_are_supported(const std::vector<FeatureT> &required_features) {
for (const auto &required_feature : required_features) {
bool is_supported = false;
for (const auto &supported_feature : get_supported_features()) {
if (required_feature == supported_feature) {
is_supported = true;
break;
}
}
if (!is_supported) {
return false;
}
}
return true;
}
inline void check_if_all_features_are_supported_and_abort(const std::vector<FeatureT> &required_features) {
if (!check_if_all_features_are_supported(required_features)) {
fprintf(stderr, "[FlashMLA] Error: The chosen implementation does not support all required features.\n");
fprintf(stderr, "Required features:\n");
for (const auto &f : required_features) {
fprintf(stderr, " - %3d: %s\n", static_cast<int>(f), get_dynamic_enum_name(f).c_str());
}
fprintf(stderr, "\n");
fprintf(stderr, "Supported features:\n");
for (const auto &supported_feature : get_supported_features()) {
fprintf(stderr, " - %3d: %s\n", static_cast<int>(supported_feature), get_dynamic_enum_name(supported_feature).c_str());
}
fprintf(stderr, "\n");
fprintf(stderr, "Features that are required but not supported:\n");
for (const auto &required_feature : required_features) {
bool is_supported = false;
for (const auto &supported_feature : get_supported_features()) {
if (required_feature == supported_feature) {
is_supported = true;
break;
}
}
if (!is_supported) {
fprintf(stderr, " - %3d: %s\n", static_cast<int>(required_feature), get_dynamic_enum_name(required_feature).c_str());
}
}
fprintf(stderr, "\n");
Arch cur_gpu_arch = Arch();
fprintf(stderr, "Current GPU: %s, SM %d.%d with %d SMs\n", cur_gpu_arch.device_prop->name, cur_gpu_arch.major, cur_gpu_arch.minor, cur_gpu_arch.num_sms);
fprintf(stderr, "This means that the dispatcher has chosen an implementation that does not support all required features. Maybe there is a bug in the dispatcher, or you have requested an invalid combination of features.\n");
TORCH_CHECK(false, "The chosen implementation does not support all required features. See message above for details.");
}
}
inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) {
check_if_all_features_are_supported_and_abort(required_features);
run_(params, required_features);
}
};
#pragma once
#include <cutlass/half.h>
#include <cutlass/fast_math.h>
#include "common.h"
#include "params.h"
#include "sm90/decode/dense/splitkv_mla.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_attn_decode_interface(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale,
bool is_causal,
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits // batch_size + 1
) {
// Check arch
Arch arch = Arch();
if (!arch.is_sm90a()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on SM90a architecture");
}
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kcache);
KU_CHECK_DEVICE(seqlens_k);
KU_CHECK_DEVICE(block_table);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576 || head_size_k == 512, "Only head_size_k == 576 or 512 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int num_sm_parts = std::max(arch.num_sms / num_heads_k / cutlass::ceil_div(seqlen_q_ori*num_heads_q/num_heads_k, 64), 1);
KU_CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
KU_CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
KU_CHECK_SHAPE(seqlens_k, batch_size);
KU_CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, DecodingSchedMetaSize/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, num_heads, q_seq_per_hk, head_size_v}, opts);
at::Tensor lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({batch_size+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
batch_size, seqlen_q_ori,
64,
5,
-1, -1,
nullptr, nullptr,
seqlens_k.data_ptr<int>(),
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
} else {
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, batch_size+1);
}
// Set the sizes
DenseAttnDecodeParams params;
params.b = batch_size;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = lse.data_ptr<float>();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(1);
params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(2);
params.q_head_stride = q.stride(2);
params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(1);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_sm_parts = num_sm_parts;
params.num_splits_ptr = num_splits->data_ptr<int>();
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = lse_accum.data_ptr<float>();
params.oaccum_ptr = out_accum.data_ptr<float>();
params.stream = at::cuda::getCurrentCUDAStream().stream();
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params);
} else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params);
#endif
} else {
TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
}
CombineParams combine_params = {
batch_size, seqlen_q_ori,
num_heads_q, head_size_v,
params.softmax_lse_ptr,
params.o_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.softmax_lseaccum_ptr,
params.oaccum_ptr,
num_heads*q_seq_per_hk, num_heads_q,
num_heads_q*seqlen_q_ori*head_size_v, num_heads_q*head_size_v, head_size_v,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
nullptr,
at::cuda::getCurrentCUDAStream().stream()
};
if (q_dtype == torch::kBFloat16) {
smxx::decode::run_flash_mla_combine_kernel<cutlass::bfloat16_t>(combine_params);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
smxx::decode::run_flash_mla_combine_kernel<cutlass::half_t>(combine_params);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk, head_size_v}).transpose(1, 2)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
lse = lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, lse, tile_scheduler_metadata, num_splits};
}
#pragma once
#include "common.h"
#include "sm100/prefill/dense/interface.h"
#pragma once
#include "common.h"
#include "params.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/decode/head64/kernel.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h"
// Feature set of sparse decoding kernels
enum class DecodeFeatures : int {
HEAD_64,
HEAD_128,
HEAD_DIM_576,
HEAD_DIM_512,
V32_KVCACHE_FORMAT,
MODEL1_KVCACHE_FORMAT,
ATTN_SINK,
TOPK_LENGTH,
EXTRA_KVCACHE,
EXTRA_TOPK_LENGTH
};
struct DecodeImplMeta {
int num_sm_parts;
int fixed_overhead_num_blocks;
int block_size_topk;
};
class DecodeImplBase : public ImplBase<
SparseAttnDecodeParams,
DecodeFeatures
> {
public:
virtual DecodeImplMeta get_meta(int h_q, int s_q) = 0;
};
class Decode_Sm90_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q / (h_q/64), 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
DISPATCH_NUM_HEADS(params.h_q, NUM_HEADS, [&]() {
sm90::decode::sparse_fp8::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE, NUM_HEADS>(params);
});
});
}
};
class Decode_Sm100_Head64_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(params);
});
}
};
// An implementation that calls the head64 kernel twice to process head128
// Necessary for running V3.2 shape (i.e. h = 128, d_qk = 576) on SM100f
class Decode_Sm100_Head64x2_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) {
SparseAttnDecodeParams cur_params = params;
cur_params.q += start_head_idx * params.stride_q_h_q;
if (cur_params.attn_sink) {
cur_params.attn_sink += start_head_idx;
}
cur_params.lse += start_head_idx;
cur_params.out += start_head_idx * params.stride_o_h_q;
cur_params.lse_accum += start_head_idx;
cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q;
cur_params.h_q = 64;
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(cur_params);
}
});
}
};
class Decode_Sm100_Head128_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q / 2, 1),
3,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(params);
}
};
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
sparse_attn_decode_interface(
const at::Tensor &q, // [b, s_q, h_q, d_qk]
const at::Tensor &kv, // [num_blocks, page_block_size, h_k, d_qk]
const at::Tensor &indices, // [b, s_q, topk]
const std::optional<at::Tensor> &topk_length, // [b, s_q]
const std::optional<at::Tensor> &attn_sink, // [h_q]
std::optional<at::Tensor> &tile_scheduler_metadata, // num_sm_parts x (DecodingSchedMetaSize/4)
std::optional<at::Tensor> &num_splits, // batch_size + 1
const std::optional<at::Tensor> &extra_kv,
const std::optional<at::Tensor> &extra_indices,
const std::optional<at::Tensor> &extra_topk_length,
int d_v,
float sm_scale
) {
using bf16 = cutlass::bfloat16_t;
// Check the architecture
Arch arch = Arch();
KU_CHECK_NDIM(q, 4);
KU_CHECK_NDIM(kv, 4);
KU_CHECK_NDIM(indices, 3);
int b = q.size(0);
int s_q = q.size(1);
int h_q = q.size(2);
int d_qk = q.size(3);
int num_blocks = kv.size(0);
int page_block_size = kv.size(1);
int h_kv = kv.size(2);
int topk = indices.size(2);
bool have_topk_length = topk_length.has_value();
bool have_extra_kcache = extra_kv.has_value();
bool have_extra_topk_length = extra_topk_length.has_value();
bool have_attn_sink = attn_sink.has_value();
int extra_num_blocks = 0, extra_page_block_size = 0, extra_topk = 0;
if (have_extra_kcache) {
extra_num_blocks = extra_kv->size(0);
extra_page_block_size = extra_kv->size(1);
}
if (extra_indices.has_value()) {
extra_topk = extra_indices->size(-1);
}
// metadata sanity check
TORCH_CHECK(b > 0);
TORCH_CHECK(s_q > 0);
TORCH_CHECK(h_q > 0);
TORCH_CHECK(h_kv == 1, "Currently only MQA (i.e. h_kv == 1) is supported for sparse decoding");
TORCH_CHECK(d_qk == 576 || d_qk == 512, "Only head_size_k == 576 or 512 is supported for sparse decoding");
TORCH_CHECK(d_v == 512, "Only head_size_v == 512 is supported for sparse decoding");
TORCH_CHECK(topk > 0);
if (have_extra_kcache) {
TORCH_CHECK(extra_indices.has_value(), "extra_indices_in_kvcache must be provided when extra_kcache is provided for sparse attention");
} else {
TORCH_CHECK(!extra_indices.has_value(), "extra_indices_in_kvcache must not be provided when extra_k_cache is not provided");
TORCH_CHECK(!extra_topk_length.has_value(), "extra_topk_length must not be provided when extra_k_cache is not provided");
}
// Check device
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kv);
KU_CHECK_DEVICE(indices);
KU_CHECK_DEVICE(topk_length);
KU_CHECK_DEVICE(attn_sink);
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DEVICE(extra_kv);
KU_CHECK_DEVICE(extra_indices);
KU_CHECK_DEVICE(extra_topk_length);
// Check data type
KU_CHECK_DTYPE(q, torch::kBFloat16);
TORCH_CHECK(kv.dtype() == torch::kFloat8_e4m3fn || kv.dtype() == torch::kInt8 || kv.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn, int8 or uint8");
if (extra_kv.has_value()) {
TORCH_CHECK(extra_kv->dtype() == torch::kFloat8_e4m3fn || extra_kv->dtype() == torch::kInt8 || extra_kv->dtype() == torch::kUInt8, "extra k cache must have dtype fp8_e4m3fn, int8 or uint8");
}
KU_CHECK_DTYPE(indices, torch::kInt32);
KU_CHECK_DTYPE(topk_length, torch::kInt32);
KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_DTYPE(extra_indices, torch::kInt32);
KU_CHECK_DTYPE(extra_topk_length, torch::kInt32);
// Check layout
KU_CHECK_LAST_DIM_CONTIGUOUS(q);
KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
KU_CHECK_CONTIGUOUS(topk_length);
KU_CHECK_CONTIGUOUS(attn_sink);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(extra_indices);
KU_CHECK_CONTIGUOUS(extra_topk_length);
// Check shape
KU_CHECK_SHAPE(q, b, s_q, h_q, d_qk);
{
int bytes_per_token;
if (d_qk == 576 && d_v == 512) {
// V3.2 style
bytes_per_token = 512 + 64*2 + (512/128)*4;
} else if (d_qk == 512 && d_v == 512) {
// MODEL1 style
bytes_per_token = 448 + 64*2 + (448/64)*1 + 1;
} else {
TORCH_CHECK(false, "Unsupported head sizes for is_fp8_kvcache == True");
}
KU_CHECK_SHAPE(kv, num_blocks, page_block_size, h_kv, bytes_per_token);
KU_CHECK_SHAPE(extra_kv, extra_num_blocks, extra_page_block_size, h_kv, bytes_per_token);
TORCH_CHECK(kv.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for kv cache");
if (extra_kv.has_value()) {
TORCH_CHECK(extra_kv->stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True for extra kv cache");
}
}
KU_CHECK_SHAPE(indices, b, s_q, topk);
KU_CHECK_SHAPE(topk_length, b);
KU_CHECK_SHAPE(attn_sink, h_q);
KU_CHECK_SHAPE(extra_indices, b, s_q, extra_topk);
KU_CHECK_SHAPE(extra_topk_length, b);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({b, s_q, h_q, d_v}, opts);
at::Tensor lse = torch::empty({b, s_q, h_q}, opts.dtype(at::kFloat));
ModelType model_type;
if (d_qk == 576) {
model_type = ModelType::V32;
} else if (d_qk == 512) {
model_type = ModelType::MODEL1;
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
std::vector<DecodeFeatures> features;
if (h_q == 64) {
features.push_back(DecodeFeatures::HEAD_64);
} else if (h_q == 128) {
features.push_back(DecodeFeatures::HEAD_128);
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
if (d_qk == 576) {
features.push_back(DecodeFeatures::HEAD_DIM_576);
} else if (d_qk == 512) {
features.push_back(DecodeFeatures::HEAD_DIM_512);
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
if (model_type == ModelType::V32) {
features.push_back(DecodeFeatures::V32_KVCACHE_FORMAT);
} else if (model_type == ModelType::MODEL1) {
features.push_back(DecodeFeatures::MODEL1_KVCACHE_FORMAT);
} else {
TORCH_CHECK(false, "Unsupported model type: ", (int)model_type);
}
if (have_attn_sink) {
features.push_back(DecodeFeatures::ATTN_SINK);
}
if (have_topk_length) {
features.push_back(DecodeFeatures::TOPK_LENGTH);
}
if (have_extra_kcache) {
features.push_back(DecodeFeatures::EXTRA_KVCACHE);
}
if (have_extra_topk_length) {
features.push_back(DecodeFeatures::EXTRA_TOPK_LENGTH);
}
DecodeImplBase* impl;
if (arch.is_sm100f()) {
if (h_q == 64) {
impl = new Decode_Sm100_Head64_Impl();
} else if (h_q == 128) {
if (d_qk == 576) {
impl = new Decode_Sm100_Head64x2_Impl();
} else if (d_qk == 512) {
impl = new Decode_Sm100_Head128_Impl();
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
} else if (arch.is_sm90a()) {
impl = new Decode_Sm90_Impl();
} else {
TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd");
}
DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q);
SparseAttnDecodeParams params = {
b, s_q, h_q, h_kv, d_qk, d_v,
sm_scale, sm_scale * LOG_2_E,
num_blocks, page_block_size, topk,
model_type,
(bf16*)q.data_ptr(),
(bf16*)kv.data_ptr(),
(int*)indices.data_ptr(),
ku::get_optional_tensor_ptr<int>(topk_length),
ku::get_optional_tensor_ptr<float>(attn_sink),
(float*)lse.data_ptr(),
(bf16*)out.data_ptr(),
extra_num_blocks, extra_page_block_size, extra_topk,
ku::get_optional_tensor_ptr<bf16>(extra_kv),
ku::get_optional_tensor_ptr<int>(extra_indices),
ku::get_optional_tensor_ptr<int>(extra_topk_length),
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)), int64_stride_to_int(q.stride(2)),
int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
int64_stride_to_int(lse.stride(0)), int64_stride_to_int(lse.stride(1)),
int64_stride_to_int(out.stride(0)), int64_stride_to_int(out.stride(1)), int64_stride_to_int(out.stride(2)),
have_extra_kcache ? int64_stride_to_int(extra_kv->stride(0)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_kv->stride(1)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_indices->stride(0)) : 0,
have_extra_kcache ? int64_stride_to_int(extra_indices->stride(1)) : 0,
at::cuda::getCurrentCUDAStream().stream()
};
// Get MLA metadata if necessary
at::Tensor o_accum, lse_accum;
if (!tile_scheduler_metadata.has_value()) {
tile_scheduler_metadata = torch::empty({impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/4}, opts.dtype(torch::kInt32));
num_splits = torch::empty({b+1}, opts.dtype(torch::kInt32));
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
GetDecodeSchedMetaParams get_sched_meta_params = {
b, s_q,
impl_meta.block_size_topk,
impl_meta.fixed_overhead_num_blocks,
topk,
extra_topk,
ku::get_optional_tensor_ptr<int>(topk_length),
ku::get_optional_tensor_ptr<int>(extra_topk_length),
nullptr,
(DecodingSchedMeta*)tile_scheduler_metadata->data_ptr(),
num_splits->data_ptr<int>(),
impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
}
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
KU_CHECK_DTYPE(tile_scheduler_metadata, torch::kInt32);
KU_CHECK_DTYPE(num_splits, torch::kInt32);
KU_CHECK_CONTIGUOUS(tile_scheduler_metadata);
KU_CHECK_CONTIGUOUS(num_splits);
KU_CHECK_SHAPE(tile_scheduler_metadata, impl_meta.num_sm_parts, sizeof(DecodingSchedMeta)/sizeof(int));
KU_CHECK_SHAPE(num_splits, b+1);
params.tile_scheduler_metadata_ptr = (DecodingSchedMeta*)tile_scheduler_metadata->data_ptr();
params.num_splits_ptr = num_splits->data_ptr<int>();
params.num_sm_parts = impl_meta.num_sm_parts;
// Allocate intermediate buffers for split-KV
const int total_num_splits = b + impl_meta.num_sm_parts;
lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat));
o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum);
KU_CHECK_CONTIGUOUS(o_accum);
params.lse_accum = lse_accum.data_ptr<float>();
params.o_accum = o_accum.data_ptr<float>();
params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0));
params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1));
params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0));
params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1));
params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2));
impl->run(params, features);
CombineParams combine_params = {
b, s_q, h_q, d_v,
params.lse,
params.out,
params.stride_lse_b, params.stride_lse_s_q,
params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q,
params.lse_accum,
params.o_accum,
params.stride_lse_accum_split, params.stride_lse_accum_s_q,
params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q,
params.tile_scheduler_metadata_ptr,
params.num_splits_ptr,
params.num_sm_parts,
ku::get_optional_tensor_ptr<float>(attn_sink),
at::cuda::getCurrentCUDAStream().stream()
};
smxx::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
delete impl;
return {out, lse.transpose(1, 2), tile_scheduler_metadata, num_splits};
}
#pragma once
#include "common.h"
#include "params.h"
#include "sm90/prefill/sparse/phase1.h"
#include "sm100/prefill/sparse/fwd/head128/phase1.h"
#include "sm100/prefill/sparse/fwd/head64/phase1.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
enum class FwdFeatures : int {
HEAD_64,
HEAD_128,
HEAD_DIM_576,
HEAD_DIM_512,
ATTN_SINK,
SINK_LSE,
TOPK_LENGTH
};
class FwdImplBase : public ImplBase<
SparseAttnFwdParams,
FwdFeatures
> {};
class Fwd_Sm90_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
});
});
}
};
class Fwd_Sm100_Head64_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_64,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::HEAD_DIM_576,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
DECLARE_SUPPORTED_FEATURES(
FwdFeatures::HEAD_128,
FwdFeatures::HEAD_DIM_512,
FwdFeatures::ATTN_SINK,
FwdFeatures::SINK_LSE,
FwdFeatures::TOPK_LENGTH
)
protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
}
};
static std::vector<at::Tensor> sparse_attn_prefill_interface(
const at::Tensor &q,
const at::Tensor &kv,
const at::Tensor &indices,
float sm_scale,
int d_v,
const std::optional<at::Tensor> &attn_sink,
const std::optional<at::Tensor> &topk_length
) {
using bf16 = cutlass::bfloat16_t;
Arch arch = Arch();
bool is_sm90a = arch.is_sm90a();
bool is_sm100f = arch.is_sm100f();
TORCH_CHECK(is_sm90a || is_sm100f, "Sparse Attention Forward Kernel is only supported on SM90a and SM100f architectures.");
KU_CHECK_NDIM(q, 3);
KU_CHECK_NDIM(kv, 3);
KU_CHECK_NDIM(indices, 3);
KU_CHECK_NDIM(attn_sink, 1);
KU_CHECK_NDIM(topk_length, 1);
int s_q = q.size(0);
int s_kv = kv.size(0);
int h_q = q.size(1);
int h_kv = kv.size(1);
int d_qk = q.size(2);
int topk = indices.size(2);
bool have_topk_length = topk_length.has_value();
TORCH_CHECK(d_qk == 576 || d_qk == 512, "Invalid d_qk: ", d_qk);
TORCH_CHECK(d_v == 512, "Invalid d_v", d_v);
KU_CHECK_DEVICE(q);
KU_CHECK_DEVICE(kv);
KU_CHECK_DEVICE(indices);
KU_CHECK_DEVICE(attn_sink);
KU_CHECK_DEVICE(topk_length);
KU_CHECK_DTYPE(q, torch::kBFloat16);
KU_CHECK_DTYPE(kv, torch::kBFloat16);
KU_CHECK_DTYPE(indices, torch::kInt32);
KU_CHECK_DTYPE(attn_sink, torch::kFloat32);
KU_CHECK_DTYPE(topk_length, torch::kInt32);
KU_CHECK_SHAPE(q, s_q, h_q, d_qk);
KU_CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
KU_CHECK_SHAPE(indices, s_q, h_kv, topk);
KU_CHECK_SHAPE(attn_sink, h_q);
KU_CHECK_SHAPE(topk_length, s_q);
KU_CHECK_LAST_DIM_CONTIGUOUS(q);
KU_CHECK_LAST_DIM_CONTIGUOUS(kv);
KU_CHECK_LAST_DIM_CONTIGUOUS(indices);
KU_CHECK_LAST_DIM_CONTIGUOUS(attn_sink);
KU_CHECK_LAST_DIM_CONTIGUOUS(topk_length);
// Allocate results and buffers
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
at::Tensor lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
at::Tensor max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
KU_CHECK_CONTIGUOUS(out);
KU_CHECK_CONTIGUOUS(lse);
KU_CHECK_CONTIGUOUS(max_logits);
SparseAttnFwdParams params = {
s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
sm_scale, sm_scale * LOG_2_E,
(bf16*)q.data_ptr(),
(bf16*)kv.data_ptr(),
(int*)indices.data_ptr(),
ku::get_optional_tensor_ptr<float>(attn_sink),
ku::get_optional_tensor_ptr<int>(topk_length),
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
(bf16*)out.data_ptr(),
(float*)max_logits.data_ptr(),
(float*)lse.data_ptr(),
arch.num_sms,
at::cuda::getCurrentCUDAStream().stream()
};
std::vector<FwdFeatures> required_features;
if (h_q == 64) {
required_features.push_back(FwdFeatures::HEAD_64);
} else if (h_q == 128) {
required_features.push_back(FwdFeatures::HEAD_128);
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
if (d_qk == 576) {
required_features.push_back(FwdFeatures::HEAD_DIM_576);
} else if (d_qk == 512) {
required_features.push_back(FwdFeatures::HEAD_DIM_512);
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
if (attn_sink.has_value()) {
required_features.push_back(FwdFeatures::ATTN_SINK);
}
if (have_topk_length) {
required_features.push_back(FwdFeatures::TOPK_LENGTH);
}
if (is_sm90a) {
Fwd_Sm90_Impl fwd_impl;
fwd_impl.run(params, required_features);
} else if (is_sm100f) {
if (h_q == 64) {
Fwd_Sm100_Head64_Impl fwd_impl;
fwd_impl.run(params, required_features);
} else if (h_q == 128) {
Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
Fwd_Sm100_Head128_Impl regular_impl;
bool use_small_topk_impl = false;
if (
(topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
!regular_impl.check_if_all_features_are_supported(required_features)
) {
use_small_topk_impl = true;
}
if (use_small_topk_impl) {
small_topk_impl.run(params, required_features);
} else {
regular_impl.run(params, required_features);
}
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
} else {
TORCH_CHECK(false, "Unsupported architecture");
}
return {out, max_logits, lse};
}
Subproject commit e94e888df3551224738bfa505787b515eae8352f
Subproject commit 147f5673d0c1c3dcf66f78d677fd647e4a020219
......@@ -3,8 +3,6 @@
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
namespace sm100 {
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
......@@ -26,5 +24,3 @@ struct bf16x8 {
__nv_bfloat162 a45;
__nv_bfloat162 a67;
};
}
#pragma once
namespace kerutils {}
#define KU_PRINTLN(fmt, ...) { cute::print(fmt, ##__VA_ARGS__); print("\n"); }
namespace ku = kerutils;
/*
Common data types and macros that are used across the kerutils library.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
#include <cute/config.hpp> // For CUTE_DEVICE
namespace kerutils {
// Cache hints
enum class CacheHint {
EVICT_FIRST,
EVICT_NORMAL,
EVICT_LAST,
EVICT_UNCHANGED,
NO_ALLOCATE
};
// Prefetch size
enum class PrefetchSize {
B64,
B128,
B256
};
using nvbf16 = __nv_bfloat16;
using nvbf16x2 = __nv_bfloat162;
using nve4m3 = __nv_fp8_e4m3;
using nve4m3x2 = __nv_fp8x2_e4m3;
using nve4m3x4 = __nv_fp8x4_e4m3;
using bf16 = cutlass::bfloat16_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define KERUTILS_ENABLE_SM80
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
static_assert(false, "kerutils doesn't support SM architectures below SM80");
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define KERUTILS_ENABLE_SM90
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000))
#define KERUTILS_ENABLE_SM90A
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
#define KERUTILS_ENABLE_SM100
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200))
#define KERUTILS_ENABLE_SM100A
#endif
#if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
#define KERUTILS_ENABLE_SM80
#define KERUTILS_ENABLE_SM90
#define KERUTILS_ENABLE_SM90A
#define KERUTILS_ENABLE_SM100
#define KERUTILS_ENABLE_SM100A
#endif
#pragma once
#include "kerutils/common/common.h"
#include "common.h"
#include "sm80/intrinsics.cuh"
#include "sm80/helpers.cuh"
#include "sm90/intrinsics.cuh"
#include "sm90/helpers.cuh"
#include "sm100/intrinsics.cuh"
#include "sm100/helpers.cuh"
#include "sm100/gemm.cuh"
#include "sm100/tma_cta_group2_nosplit.cuh"
#pragma once
#include <cute/tensor.hpp>
#include "sm100/defines.h"
namespace sm100 {
#include "kerutils/device/common.h"
using namespace cute;
using _72 = Int<72>;
using _576 = Int<576>;
template<
typename TMA,
typename Tensor0,
typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
const TMA &tma_copy,
Tensor0 src,
Tensor1 dst,
transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
auto thr_tma = tma_copy.get_slice(_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
}
namespace kerutils {
// Perform SS UTCMMA
// sA and sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tC_frag should be tmem fragment
template<
typename TiledMMA,
typename TensorA,
......@@ -38,13 +15,14 @@ template<
typename TensorFragC
>
CUTE_DEVICE
void utcmma(
void utcmma_ss(
TiledMMA &tiled_mma,
TensorA sA,
TensorB sB,
TensorFragC tC_frag,
bool clear_accum
) {
using namespace cute;
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sA_frag = thr_mma.partition_fragment_A(sA);
......@@ -64,6 +42,8 @@ void utcmma(
}
}
// Perform TS UTCMMA
// sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tA_frag and tC_frag should be tmem fragment
template<
typename TiledMMA,
typename TensorA,
......@@ -78,6 +58,7 @@ void utcmma_ts(
TensorFragC tC_frag,
bool clear_accum
) {
using namespace cute;
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sB_frag = thr_mma.partition_fragment_B(sB);
......@@ -94,11 +75,63 @@ void utcmma_ts(
}
}
struct bf16x8 {
__nv_bfloat162 a01;
__nv_bfloat162 a23;
__nv_bfloat162 a45;
__nv_bfloat162 a67;
};
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_k_major_layout() {
using namespace cute;
using base_atom_type = \
std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16,
UMMA::Layout_K_INTER_Atom<T>,
std::conditional_t<SWIZZLE == 32,
UMMA::Layout_K_SW32_Atom<T>,
std::conditional_t<SWIZZLE == 64,
UMMA::Layout_K_SW64_Atom<T>,
std::conditional_t<SWIZZLE == 128,
UMMA::Layout_K_SW128_Atom<T>,
void
>
>
>
>;
static_assert(!std::is_same_v<base_atom_type, void>, "Invalid SWIZZLE value");
return coalesce(tile_to_shape(
base_atom_type{},
Shape<Int<MN>, Int<K>>{},
Step<_1, _2>{}
), Shape<_1, _1>{});
}
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_mn_major_layout() {
using namespace cute;
using base_atom_type = \
std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16,
UMMA::Layout_MN_INTER_Atom<T>,
std::conditional_t<SWIZZLE == 32,
UMMA::Layout_MN_SW32_Atom<T>,
std::conditional_t<SWIZZLE == 64,
UMMA::Layout_MN_SW64_Atom<T>,
std::conditional_t<SWIZZLE == 128,
UMMA::Layout_MN_SW128_Atom<T>,
void
>
>
>
>;
static_assert(!std::is_same_v<base_atom_type, void>, "Invalid SWIZZLE value");
return coalesce(tile_to_shape(
base_atom_type{},
Shape<Int<MN>, Int<K>>{},
Step<_2, _1>{}
), Shape<_1, _1>{});
}
template<cute::UMMA::Major MAJOR, int MN, int K, int SWIZZLE, typename T = bf16>
auto make_umma_canonical_layout() {
if constexpr (MAJOR == cute::UMMA::Major::K) {
return make_umma_canonical_k_major_layout<MN, K, SWIZZLE, T>();
} else {
return make_umma_canonical_mn_major_layout<MN, K, SWIZZLE, T>();
}
}
}
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
CUTE_DEVICE
void tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
CUTE_DEVICE
void tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
asm volatile(
"cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\n"
:
: "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"l"(cache_hint)
);
}
// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= cute::Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add)
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"add.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b))
);
return c;
}
// Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul)
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"mul.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)));
return c;
}
// Vectorized fused addition-multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
// return a*b+c
float2 d;
asm volatile(
"fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
return d;
}
// Vectorized negation for foat32
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
float2 t = {-1.0f, -1.0f};
return float2_mul(a, t);
}
// st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)
CUTE_DEVICE
void st_bulk(void* dst_ptr, int64_t size) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
asm volatile (
"st.bulk.weak.shared::cta [%0], %1, 0;\n"
:
: "r"(dst_addr), "l"(size)
: "memory"
);
}
struct CUTE_ALIGNAS(16) CLCResponseObj {
// An opaque 16B value
char opaque[16];
};
struct CLCResult {
int is_valid;
int x, y, z;
};
// Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
CUTE_DEVICE
void issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
CUTE_DEVICE
void issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)
// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions
template<bool USE_LD_ACQUIRE>
CUTE_DEVICE
CLCResult get_clc_query_response(CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj);
CLCResult result;
#define EMIT_ASM(LD_MODIFIER) \
asm volatile( \
"{\n" \
".reg .pred p1;\n\t" \
".reg .b128 clc_result;\n\t" \
"ld" LD_MODIFIER ".shared.b128 clc_result, [%4];\n\t" \
"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" \
"selp.u32 %3, 1, 0, p1;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\n\t" \
"}\n" \
: "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.is_valid) \
: "r"(response_addr) \
: "memory" \
);
if constexpr (USE_LD_ACQUIRE) {
EMIT_ASM(".acquire.cta");
} else {
EMIT_ASM("");
}
return result;
}
// LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
#define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v4.u64 {%0, %1, %2, %3}, [%4];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]), \
"=l"(result_as_uint64_ptr[2]), "=l"(result_as_uint64_ptr[3]) \
: "l"(global_addr) \
); \
}
// STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st)
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
#define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(src)> || std::is_array_v<decltype(src)>, "`src` must be a pointer"); \
uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \
asm volatile( \
"st.global.L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".v4.u64 [%0], {%1, %2, %3, %4};\n" \
: \
: "l"(global_addr), "l"(src_as_uint64_ptr[0]), "l"(src_as_uint64_ptr[1]), \
"l"(src_as_uint64_ptr[2]), "l"(src_as_uint64_ptr[3]) \
); \
}
}
namespace kerutils {
// tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_noelect(transac_bar_t &bar) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
:
:"r"(bar_intptr)
);
}
// tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
:
:"r"(bar_intptr), "h"(cta_mask)
);
}
// tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_2x1SM_noelect(transac_bar_t &bar) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
:
:"r"(bar_intptr)
);
}
// tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
:
:"r"(bar_intptr), "h"(cta_mask)
);
}
// tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumElements>
__device__ __forceinline__
void tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
// NOTE The following code crashes VSCode intellisense engine, so we disable it
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 64) {
cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*2>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*4>{});
#endif
}
// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
template <int kNumElements>
__device__ __forceinline__
void tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) {
uint32_t const* data = (uint32_t const*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
}
......@@ -2,15 +2,16 @@
#include <cute/tensor.hpp>
#include <kerutils/device/common.h>
namespace cute {
// Extensions to CuTe
// CuTe's SM100_TMA_2SM_LOAD_1D requires two threads to perform this operation cooperatively (using ThrID = Layout<_2>;), which doesn't fit our use case.
// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -36,7 +37,6 @@ struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
#endif
}
};
struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -62,7 +62,6 @@ struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
#endif
}
};
struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -88,7 +87,6 @@ struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
#endif
}
};
struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -114,7 +112,6 @@ struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
#endif
}
};
struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -140,7 +137,6 @@ struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
#endif
}
};
struct SM100_TMA_2SM_LOAD_NOSPLIT
{
CUTE_HOST_DEVICE static void
......@@ -178,14 +174,9 @@ struct SM100_TMA_2SM_LOAD_NOSPLIT
{
return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
}
using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
};
struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {};
// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBitsPerTMA, class AuxParams_>
......@@ -198,19 +189,16 @@ struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
......@@ -221,7 +209,6 @@ struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
// We accept multicast_mask here to keep the API for both atoms consistent
return {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
......@@ -233,16 +220,14 @@ struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
// We accept multicast_mask here to keep the API for both atoms consistent
return {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
......@@ -251,7 +236,6 @@ struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
......@@ -264,18 +248,15 @@ struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint64_t // cache hint
> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache)
: opargs_(desc, mbar, cache) {}
};
}
#pragma once
#include "kerutils/device/common.h"
#include "kerutils/device/sm80/intrinsics.cuh"
namespace kerutils {
// Retrieve the value of `%smid` and check its range
CUTE_DEVICE
uint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) {
uint32_t sm_id = get_sm_id();
if (!(sm_id < num_physical_sms)) {
trap();
}
return sm_id;
}
#ifndef KU_TRAP_ONLY_DEVICE_ASSERT
#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// Construct a `float2` from a single `float` by duplicating the value
CUTE_DEVICE
float2 float2float2(const float &x) {
return float2 {x, x};
}
CUTE_DEVICE
void st_shared(void* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
CUTE_DEVICE
void st_shared(void* ptr, float4 val) {
st_shared(ptr, *(__int128_t*)&val);
}
CUTE_DEVICE
__int128_t ld_shared(void* ptr) {
__int128_t val;
asm volatile("ld.shared.b128 %0, [%1];" : "=q"(val) : "l"(__cvta_generic_to_shared(ptr)));
return val;
}
CUTE_DEVICE
float4 ld_shared_float4(void* ptr) {
__int128_t temp = ld_shared(ptr);
return *(float4*)&temp;
}
}
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)
template<PrefetchSize PREFETCH_SIZE=PrefetchSize::B128>
CUTE_DEVICE
void cp_async_cacheglobal(const void* src, void* dst, bool pred=true) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
if constexpr (PREFETCH_SIZE == PrefetchSize::B64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else if constexpr (PREFETCH_SIZE == PrefetchSize::B256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else {
static_assert(PREFETCH_SIZE == PrefetchSize::B64 ||
PREFETCH_SIZE == PrefetchSize::B128 ||
PREFETCH_SIZE == PrefetchSize::B256,
"Unsupported prefetch size for cp_async_cacheglobal.");
}
}
// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy)
template<CacheHint PRIMARY_PRIORITY, CacheHint SECONDARY_PRIORITY>
CUTE_DEVICE
int64_t create_fraction_based_cache_policy(float fraction = 1.0f) {
int64_t result;
#define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \
asm volatile( \
"createpolicy.fractional.L2::" PRIMARY_PRIORITY_STR ".L2::" SECONDARY_PRIORITY_STR ".b64 %0, %1;\n" \
: "=l"(result) \
: "f"(fraction) \
);
#define EMIT2(PRIMARY_PRIORITY_STR) \
{ \
if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_first") \
} else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_unchanged") \
} else { \
static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \
SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \
"Unsupported secondary cache hint for create_fraction_based_cache_policy."); \
} \
}
if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) {
EMIT2("evict_first");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) {
EMIT2("evict_normal");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) {
EMIT2("evict_last");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) {
EMIT2("evict_unchanged");
} else {
static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST ||
PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL ||
PRIMARY_PRIORITY == CacheHint::EVICT_LAST ||
PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED,
"Unsupported primary cache hint for create_fraction_based_cache_policy.");
}
#undef EMIT
#undef EMIT2
return result;
}
// Create a simple cache policy (equivalent to create_fraction_based_cache_policy(1.0f))
// The same as cute::TMA::CacheHintSmXX
template<CacheHint CACHE_HINT>
CUTE_DEVICE
constexpr int64_t create_simple_cache_policy() {
if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) {
return 0x12F0000000000000; // Result of createpolicy.fractional.L2::evict_first.b64
} else if constexpr (CACHE_HINT == CacheHint::EVICT_NORMAL) {
return 0x1000000000000000; // Copied from CuTe. Unsure about the exact meaning. (TODO Change to 0x16F0000000000000?)
} else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) {
return 0x14F0000000000000; // Result of createpolicy.fractional.L2::evict_last.b64
} else {
static_assert(CACHE_HINT == CacheHint::EVICT_FIRST ||
CACHE_HINT == CacheHint::EVICT_NORMAL ||
CACHE_HINT == CacheHint::EVICT_LAST,
"Unsupported cache hint for create_simple_cache_policy.");
}
}
// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
CUTE_DEVICE
void atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %3, 1;\n\t"
"@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \n\t"
"}"
:
: "f"(data),
"l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
);
}
// Get the id of the current SM
// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): PTX document says that %smid ranges from 0 to %nsmid-1, while "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, result shows that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. For the sake of safety, I recommend you to check the return of get_sm_id manually or call `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`.
// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT]
CUTE_DEVICE
uint32_t get_sm_id() {
uint32_t ret;
asm volatile("mov.u32 %0, %%smid;\n" : "=r"(ret));
return ret;
}
// trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap)
CUTE_DEVICE
void trap() {
asm volatile("trap;\n");
}
// LDG.128 or LDG.128 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
// L2 cache hint is not supported since it's only supported for LDG.256
#define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v2.u64 {%0, %1}, [%2];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]) \
: "l"(global_addr) \
); \
}
}
#pragma once
#include <cute/tensor.hpp>
#include "kerutils/device/common.h"
namespace kerutils {
template<
typename TMA,
typename Tensor0,
typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
const TMA &tma_copy,
Tensor0 src,
Tensor1 dst,
transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
auto thr_tma = tma_copy.get_slice(cute::_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
}
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE
int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
}
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE
int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
return col_idx;
}
template <bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) {
using namespace cute;
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
warpgroup_arrive();
tiled_mma.accumulate_ = zero_init ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
if constexpr (commit) {
warpgroup_commit_batch();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
}
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE
void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(rA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(rA_frag); ++k) {
cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment