Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
...@@ -10,3 +10,6 @@ compile_commands.json ...@@ -10,3 +10,6 @@ compile_commands.json
.cache .cache
/dev /dev
/.clangd /.clangd
*.log
**/*_hip*
**/*.hip
...@@ -3,13 +3,11 @@ ...@@ -3,13 +3,11 @@
#include "sparse_fwd.h" #include "sparse_fwd.h"
#include "sparse_decode.h" #include "sparse_decode.h"
#include "dense_decode.h" #include "dense_decode.h"
#include "dense_fwd.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA"; m.doc() = "FlashMLA";
m.def("sparse_decode_fwd", &sparse_attn_decode_interface); m.def("sparse_decode_fwd", &sparse_attn_decode_interface);
m.def("dense_decode_fwd", &dense_attn_decode_interface); m.def("dense_decode_fwd", &dense_attn_decode_interface);
m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface); m.def("sparse_prefill_fwd", &sparse_attn_prefill_interface);
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
} }
...@@ -129,8 +129,10 @@ static constexpr std::size_t get_enum_max(){ ...@@ -129,8 +129,10 @@ static constexpr std::size_t get_enum_max(){
return N; return N;
} }
template<typename T> requires std::is_enum_v<T> template<typename T>
static constexpr std::string get_dynamic_enum_name(T value){ static constexpr std::string get_dynamic_enum_name(T value){
static_assert(std::is_enum<T>::value,
"Template parameter T must be an enumeration type");
constexpr std::size_t num = get_enum_max<T>(); constexpr std::size_t num = get_enum_max<T>();
constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){ constexpr auto names = []<std::size_t... Is>(std::index_sequence<Is...>){
return std::array<std::string_view, num>{ return std::array<std::string_view, num>{
...@@ -140,12 +142,30 @@ static constexpr std::string get_dynamic_enum_name(T value){ ...@@ -140,12 +142,30 @@ static constexpr std::string get_dynamic_enum_name(T value){
return (std::string)names[static_cast<std::size_t>(value)]; return (std::string)names[static_cast<std::size_t>(value)];
} }
template<typename T>
class SimpleSpan {
private:
const T* data_;
size_t size_;
public:
constexpr SimpleSpan(const T* data, size_t size) : data_(data), size_(size) {}
constexpr SimpleSpan(const T* begin, const T* end) : data_(begin), size_(end - begin) {}
constexpr const T* data() const { return data_; }
constexpr size_t size() const { return size_; }
constexpr const T* begin() const { return data_; }
constexpr const T* end() const { return data_ + size_; }
constexpr const T& operator[](size_t index) const { return data_[index]; }
};
// A shortcut macro to declare supported features in an implementation class. // A shortcut macro to declare supported features in an implementation class.
#define DECLARE_SUPPORTED_FEATURES(...) \ #define DECLARE_SUPPORTED_FEATURES(...) \
protected: \ protected: \
static constexpr FeatureT features[] = { __VA_ARGS__ }; \ static constexpr FeatureT features[] = { __VA_ARGS__ }; \
constexpr inline std::span<const FeatureT> get_supported_features() const override { \ constexpr inline SimpleSpan<const FeatureT> get_supported_features() const override { \
return features; \ return SimpleSpan<const FeatureT>(features, std::size(features)); \
} }
/* /*
...@@ -168,7 +188,7 @@ protected: ...@@ -168,7 +188,7 @@ protected:
virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0; virtual inline void run_(const RunArgT &params, const std::vector<FeatureT> &required_features) = 0;
constexpr virtual inline std::span<const FeatureT> get_supported_features() const = 0; constexpr virtual inline SimpleSpan<const FeatureT> get_supported_features() const = 0;
virtual ~ImplBase() = default; virtual ~ImplBase() = default;
...@@ -224,7 +244,7 @@ public: ...@@ -224,7 +244,7 @@ public:
} }
inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) { inline void run(const RunArgT &params, const std::vector<FeatureT> &required_features) {
check_if_all_features_are_supported_and_abort(required_features); // check_if_all_features_are_supported_and_abort(required_features);
run_(params, required_features); run_(params, required_features);
} }
}; };
......
#pragma once
#include "common.h"
#include "sm100/prefill/dense/interface.h"
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
#include "params.h" #include "params.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/decode/head64/kernel.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
#include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h" #include "smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "smxx/decode/combine/combine.h" #include "smxx/decode/combine/combine.h"
...@@ -75,111 +73,6 @@ protected: ...@@ -75,111 +73,6 @@ protected:
} }
}; };
class Decode_Sm100_Head64_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_64,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(params);
});
}
};
// An implementation that calls the head64 kernel twice to process head128
// Necessary for running V3.2 shape (i.e. h = 128, d_qk = 576) on SM100f
class Decode_Sm100_Head64x2_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::HEAD_DIM_576,
DecodeFeatures::V32_KVCACHE_FORMAT,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q, 1),
5,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_MODEL_TYPE(params.model_type, MODEL_TYPE, [&]() {
for (int start_head_idx = 0; start_head_idx < 128; start_head_idx += 64) {
SparseAttnDecodeParams cur_params = params;
cur_params.q += start_head_idx * params.stride_q_h_q;
if (cur_params.attn_sink) {
cur_params.attn_sink += start_head_idx;
}
cur_params.lse += start_head_idx;
cur_params.out += start_head_idx * params.stride_o_h_q;
cur_params.lse_accum += start_head_idx;
cur_params.o_accum += start_head_idx * params.stride_o_accum_h_q;
cur_params.h_q = 64;
sm100::decode::head64::run_flash_splitkv_mla_fp8_sparse_kernel<MODEL_TYPE>(cur_params);
}
});
}
};
class Decode_Sm100_Head128_Impl : public DecodeImplBase {
DECLARE_SUPPORTED_FEATURES(
DecodeFeatures::HEAD_128,
DecodeFeatures::HEAD_DIM_512,
DecodeFeatures::MODEL1_KVCACHE_FORMAT,
DecodeFeatures::ATTN_SINK,
DecodeFeatures::TOPK_LENGTH,
DecodeFeatures::EXTRA_KVCACHE,
DecodeFeatures::EXTRA_TOPK_LENGTH
)
public:
DecodeImplMeta get_meta(int h_q, int s_q) override {
Arch arch = Arch();
return {
std::max(arch.num_sms / s_q / 2, 1),
3,
64
};
}
protected:
void run_(const SparseAttnDecodeParams &params, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(params);
}
};
static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> static std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
sparse_attn_decode_interface( sparse_attn_decode_interface(
const at::Tensor &q, // [b, s_q, h_q, d_qk] const at::Tensor &q, // [b, s_q, h_q, d_qk]
...@@ -360,21 +253,7 @@ sparse_attn_decode_interface( ...@@ -360,21 +253,7 @@ sparse_attn_decode_interface(
} }
DecodeImplBase* impl; DecodeImplBase* impl;
if (arch.is_sm100f()) { if (arch.is_sm90a()) {
if (h_q == 64) {
impl = new Decode_Sm100_Head64_Impl();
} else if (h_q == 128) {
if (d_qk == 576) {
impl = new Decode_Sm100_Head64x2_Impl();
} else if (d_qk == 512) {
impl = new Decode_Sm100_Head128_Impl();
} else {
TORCH_CHECK(false, "Unsupported d_qk: ", d_qk);
}
} else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q);
}
} else if (arch.is_sm90a()) {
impl = new Decode_Sm90_Impl(); impl = new Decode_Sm90_Impl();
} else { } else {
TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd"); TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd");
......
...@@ -4,10 +4,8 @@ ...@@ -4,10 +4,8 @@
#include "params.h" #include "params.h"
#include "sm90/prefill/sparse/phase1.h" // #include "sm90/prefill/sparse/phase1.h"
#include "sm100/prefill/sparse/fwd/head128/phase1.h"
#include "sm100/prefill/sparse/fwd/head64/phase1.h"
#include "sm100/prefill/sparse/fwd_for_small_topk/head128/phase1.h"
enum class FwdFeatures : int { enum class FwdFeatures : int {
HEAD_64, HEAD_64,
...@@ -41,7 +39,7 @@ protected: ...@@ -41,7 +39,7 @@ protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() { DISPATCH_BOOLEAN_FLAG(params.topk_length != nullptr, HAVE_TOPK_LENGTH, [&]() {
sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params); // sm90::fwd::run_fwd_phase1_kernel<HEAD_DIM_QK, HAVE_TOPK_LENGTH>(params);
}); });
}); });
} }
...@@ -60,7 +58,7 @@ class Fwd_Sm100_Head64_Impl : public FwdImplBase { ...@@ -60,7 +58,7 @@ class Fwd_Sm100_Head64_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params); // sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
}); });
} }
}; };
...@@ -78,7 +76,7 @@ class Fwd_Sm100_Head128_Impl : public FwdImplBase { ...@@ -78,7 +76,7 @@ class Fwd_Sm100_Head128_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() { DISPATCH_HEAD_DIM(params.d_qk, HEAD_DIM_QK, [&]() {
sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params); // sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
}); });
} }
}; };
...@@ -94,7 +92,7 @@ class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase { ...@@ -94,7 +92,7 @@ class Fwd_Sm100_Head128_Small_TopK_Impl : public FwdImplBase {
protected: protected:
void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override { void run_(const SparseAttnFwdParams &params, const std::vector<FeatureT> &required_features) override {
sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params); // sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
} }
}; };
...@@ -210,34 +208,34 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -210,34 +208,34 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
required_features.push_back(FwdFeatures::TOPK_LENGTH); required_features.push_back(FwdFeatures::TOPK_LENGTH);
} }
if (is_sm90a) { // if (is_sm90a) {
Fwd_Sm90_Impl fwd_impl; // Fwd_Sm90_Impl fwd_impl;
fwd_impl.run(params, required_features); // fwd_impl.run(params, required_features);
} else if (is_sm100f) { // } else if (is_sm100f) {
if (h_q == 64) { // if (h_q == 64) {
Fwd_Sm100_Head64_Impl fwd_impl; // Fwd_Sm100_Head64_Impl fwd_impl;
fwd_impl.run(params, required_features); // fwd_impl.run(params, required_features);
} else if (h_q == 128) { // } else if (h_q == 128) {
Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl; // Fwd_Sm100_Head128_Small_TopK_Impl small_topk_impl;
Fwd_Sm100_Head128_Impl regular_impl; // Fwd_Sm100_Head128_Impl regular_impl;
bool use_small_topk_impl = false; // bool use_small_topk_impl = false;
if ( // if (
(topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) || // (topk <= 1280 && small_topk_impl.check_if_all_features_are_supported(required_features)) ||
!regular_impl.check_if_all_features_are_supported(required_features) // !regular_impl.check_if_all_features_are_supported(required_features)
) { // ) {
use_small_topk_impl = true; // use_small_topk_impl = true;
} // }
if (use_small_topk_impl) { // if (use_small_topk_impl) {
small_topk_impl.run(params, required_features); // small_topk_impl.run(params, required_features);
} else { // } else {
regular_impl.run(params, required_features); // regular_impl.run(params, required_features);
} // }
} else { // } else {
TORCH_CHECK(false, "Unsupported h_q: ", h_q); // TORCH_CHECK(false, "Unsupported h_q: ", h_q);
} // }
} else { // } else {
TORCH_CHECK(false, "Unsupported architecture"); // TORCH_CHECK(false, "Unsupported architecture");
} // }
return {out, max_logits, lse}; return {out, max_logits, lse};
} }
#pragma once #pragma once
#include <cutlass/bfloat16.h> #include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
using bf16 = cutlass::bfloat16_t; using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t; using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; // using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared; // using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init; // using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier; // using cutlass::arch::NamedBarrier;
struct int32x8_t { // struct int32x8_t {
int a0, a1, a2, a3, a4, a5, a6, a7; // int a0, a1, a2, a3, a4, a5, a6, a7;
}; // };
struct float8 { // struct float8 {
float2 a01, a23, a45, a67; // // float2 a01, a23, a45, a67;
}; // };
struct bf16x8 { // struct bf16x8 {
__nv_bfloat162 a01; // // __nv_bfloat162 a01;
__nv_bfloat162 a23; // // __nv_bfloat162 a23;
__nv_bfloat162 a45; // // __nv_bfloat162 a45;
__nv_bfloat162 a67; // // __nv_bfloat162 a67;
}; // };
...@@ -3,68 +3,68 @@ Common data types and macros that are used across the kerutils library. ...@@ -3,68 +3,68 @@ Common data types and macros that are used across the kerutils library.
*/ */
#pragma once #pragma once
#include <cuda_bf16.h> // #include <cuda_bf16.h>
#include <cuda_fp8.h> // #include <cuda_fp8.h>
#include <cutlass/bfloat16.h> #include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
#include <cute/config.hpp> // For CUTE_DEVICE #include <cute/config.hpp> // For CUTE_DEVICE
namespace kerutils { namespace kerutils {
// Cache hints // // Cache hints
enum class CacheHint { // enum class CacheHint {
EVICT_FIRST, // EVICT_FIRST,
EVICT_NORMAL, // EVICT_NORMAL,
EVICT_LAST, // EVICT_LAST,
EVICT_UNCHANGED, // EVICT_UNCHANGED,
NO_ALLOCATE // NO_ALLOCATE
}; // };
// Prefetch size // // Prefetch size
enum class PrefetchSize { // enum class PrefetchSize {
B64, // B64,
B128, // B128,
B256 // B256
}; // };
using nvbf16 = __nv_bfloat16; // using nvbf16 = __nv_bfloat16;
using nvbf16x2 = __nv_bfloat162; // using nvbf16x2 = __nv_bfloat162;
using nve4m3 = __nv_fp8_e4m3; // using nve4m3 = __nv_fp8_e4m3;
using nve4m3x2 = __nv_fp8x2_e4m3; // using nve4m3x2 = __nv_fp8x2_e4m3;
using nve4m3x4 = __nv_fp8x4_e4m3; // using nve4m3x4 = __nv_fp8x4_e4m3;
using bf16 = cutlass::bfloat16_t; using bf16 = cutlass::bfloat16_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier; // using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
} }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define KERUTILS_ENABLE_SM80 // #define KERUTILS_ENABLE_SM80
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) // #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
static_assert(false, "kerutils doesn't support SM architectures below SM80"); // static_assert(false, "kerutils doesn't support SM architectures below SM80");
#endif // #endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#define KERUTILS_ENABLE_SM90 // #define KERUTILS_ENABLE_SM90
#endif // #endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000)) // #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000))
#define KERUTILS_ENABLE_SM90A // #define KERUTILS_ENABLE_SM90A
#endif // #endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) // #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
#define KERUTILS_ENABLE_SM100 // #define KERUTILS_ENABLE_SM100
#endif // #endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) // #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200))
#define KERUTILS_ENABLE_SM100A // #define KERUTILS_ENABLE_SM100A
#endif // #endif
#if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__)) // #if (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
#define KERUTILS_ENABLE_SM80 // #define KERUTILS_ENABLE_SM80
#define KERUTILS_ENABLE_SM90 // #define KERUTILS_ENABLE_SM90
#define KERUTILS_ENABLE_SM90A // #define KERUTILS_ENABLE_SM90A
#define KERUTILS_ENABLE_SM100 // #define KERUTILS_ENABLE_SM100
#define KERUTILS_ENABLE_SM100A // #define KERUTILS_ENABLE_SM100A
#endif // #endif
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
#include "kerutils/common/common.h" #include "kerutils/common/common.h"
#include "common.h" #include "common.h"
#include "sm80/intrinsics.cuh" // #include "sm80/intrinsics.cuh"
#include "sm80/helpers.cuh" // #include "sm80/helpers.cuh"
#include "sm90/intrinsics.cuh" // #include "sm90/intrinsics.cuh"
#include "sm90/helpers.cuh" // #include "sm90/helpers.cuh"
#include "sm100/intrinsics.cuh" // #include "sm100/intrinsics.cuh"
#include "sm100/helpers.cuh" // #include "sm100/helpers.cuh"
#include "sm100/gemm.cuh" // #include "sm100/gemm.cuh"
#include "sm100/tma_cta_group2_nosplit.cuh" // #include "sm100/tma_cta_group2_nosplit.cuh"
This diff is collapsed.
#pragma once
#include <cute/tensor.hpp>
#include "kerutils/device/common.h"
namespace kerutils {
// Perform SS UTCMMA
// sA and sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tC_frag should be tmem fragment
template<
typename TiledMMA,
typename TensorA,
typename TensorB,
typename TensorFragC
>
CUTE_DEVICE
void utcmma_ss(
TiledMMA &tiled_mma,
TensorA sA,
TensorB sB,
TensorFragC tC_frag,
bool clear_accum
) {
using namespace cute;
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sA_frag = thr_mma.partition_fragment_A(sA);
auto sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag));
static_assert(size<1>(sA_frag) == size<1>(tC_frag));
static_assert(size<1>(sB_frag) == size<2>(tC_frag));
CUTE_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(
tiled_mma,
sA_frag(_, _, k),
sB_frag(_, _, k),
tC_frag
);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
}
// Perform TS UTCMMA
// sB should be shared memory tensors (i.e. make_tensor(make_shared_ptr(XXX), XXX)) while tA_frag and tC_frag should be tmem fragment
template<
typename TiledMMA,
typename TensorA,
typename TensorB,
typename TensorFragC
>
CUTE_DEVICE
void utcmma_ts(
TiledMMA &tiled_mma,
TensorA tA_frag,
TensorB sB,
TensorFragC tC_frag,
bool clear_accum
) {
using namespace cute;
tiled_mma.accumulate_ = clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
auto sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(tA_frag) == size<2>(sB_frag));
CUTE_UNROLL
for (int k = 0; k < size<2>(tA_frag); ++k) {
cute::gemm(
tiled_mma,
tA_frag(_, _, k),
sB_frag(_, _, k),
tC_frag
);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
}
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_k_major_layout() {
using namespace cute;
using base_atom_type = \
std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16,
UMMA::Layout_K_INTER_Atom<T>,
std::conditional_t<SWIZZLE == 32,
UMMA::Layout_K_SW32_Atom<T>,
std::conditional_t<SWIZZLE == 64,
UMMA::Layout_K_SW64_Atom<T>,
std::conditional_t<SWIZZLE == 128,
UMMA::Layout_K_SW128_Atom<T>,
void
>
>
>
>;
static_assert(!std::is_same_v<base_atom_type, void>, "Invalid SWIZZLE value");
return coalesce(tile_to_shape(
base_atom_type{},
Shape<Int<MN>, Int<K>>{},
Step<_1, _2>{}
), Shape<_1, _1>{});
}
template<int MN, int K, int SWIZZLE, typename T = bf16>
static constexpr auto make_umma_canonical_mn_major_layout() {
using namespace cute;
using base_atom_type = \
std::conditional_t<SWIZZLE == 0 || SWIZZLE == 16,
UMMA::Layout_MN_INTER_Atom<T>,
std::conditional_t<SWIZZLE == 32,
UMMA::Layout_MN_SW32_Atom<T>,
std::conditional_t<SWIZZLE == 64,
UMMA::Layout_MN_SW64_Atom<T>,
std::conditional_t<SWIZZLE == 128,
UMMA::Layout_MN_SW128_Atom<T>,
void
>
>
>
>;
static_assert(!std::is_same_v<base_atom_type, void>, "Invalid SWIZZLE value");
return coalesce(tile_to_shape(
base_atom_type{},
Shape<Int<MN>, Int<K>>{},
Step<_2, _1>{}
), Shape<_1, _1>{});
}
template<cute::UMMA::Major MAJOR, int MN, int K, int SWIZZLE, typename T = bf16>
auto make_umma_canonical_layout() {
if constexpr (MAJOR == cute::UMMA::Major::K) {
return make_umma_canonical_k_major_layout<MN, K, SWIZZLE, T>();
} else {
return make_umma_canonical_mn_major_layout<MN, K, SWIZZLE, T>();
}
}
}
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// tma gather4 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
CUTE_DEVICE
void tma_gather4(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// tma gather4 prefetch (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
// Please pay attention that the coordinates of TMA gather4 are int32, which may lead to overflow under some scenarios
CUTE_DEVICE
void tma_gather4_prefetch(const void* desc_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
asm volatile(
"cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4.L2::cache_hint [%0, {%1, %2, %3, %4, %5}], %6;\n"
:
: "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"l"(cache_hint)
);
}
// tma gather4 with cta_group::2, allowing for synchronization across CTAs within a pair of CTAs (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4_cta_group_2(const void* desc_ptr, transac_bar_t &mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, int64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= cute::Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(cache_hint)
: "memory"
);
}
// Vectorized addition for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add)
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"add.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b))
);
return c;
}
// Vectorized multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-mul)
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
float2 c;
asm volatile(
"mul.f32x2 %0, %1, %2;\n"
: "=l"(reinterpret_cast<uint64_t&>(c))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)));
return c;
}
// Vectorized fused addition-multiplication for float32 (https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
// return a*b+c
float2 d;
asm volatile(
"fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
return d;
}
// Vectorized negation for foat32
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
float2 t = {-1.0f, -1.0f};
return float2_mul(a, t);
}
// st.bulk (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)
CUTE_DEVICE
void st_bulk(void* dst_ptr, int64_t size) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
asm volatile (
"st.bulk.weak.shared::cta [%0], %1, 0;\n"
:
: "r"(dst_addr), "l"(size)
: "memory"
);
}
struct CUTE_ALIGNAS(16) CLCResponseObj {
// An opaque 16B value
char opaque[16];
};
struct CLCResult {
int is_valid;
int x, y, z;
};
// Issue a CLC try_cancel query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
CUTE_DEVICE
void issue_clc_query(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Issue a CLC try_cancel query with .multicast::cluster::all (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
CUTE_DEVICE
void issue_clc_query_multicast_cluster_all(transac_bar_t &bar, CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(response_obj.opaque);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [%0], [%1];\n"
:
: "r"(response_addr), "r"(mbarrier_addr)
);
}
// Get the result of a CLC query (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)
// In this function, we separate get_first_ctaid::x/y/z and hope PTXAS's dead code elimination can remove unnecessary instructions
template<bool USE_LD_ACQUIRE>
CUTE_DEVICE
CLCResult get_clc_query_response(CLCResponseObj &response_obj) {
uint32_t response_addr = cute::cast_smem_ptr_to_uint(&response_obj);
CLCResult result;
#define EMIT_ASM(LD_MODIFIER) \
asm volatile( \
"{\n" \
".reg .pred p1;\n\t" \
".reg .b128 clc_result;\n\t" \
"ld" LD_MODIFIER ".shared.b128 clc_result, [%4];\n\t" \
"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;\n\t" \
"selp.u32 %3, 1, 0, p1;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128 %0, clc_result;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::y.b32.b128 %1, clc_result;\n\t" \
"@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid::z.b32.b128 %2, clc_result;\n\t" \
"}\n" \
: "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.is_valid) \
: "r"(response_addr) \
: "memory" \
);
if constexpr (USE_LD_ACQUIRE) {
EMIT_ASM(".acquire.cta");
} else {
EMIT_ASM("");
}
return result;
}
// LDG.256 or LDG.256 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
#define KU_LDG_256(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v4.u64 {%0, %1, %2, %3}, [%4];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]), \
"=l"(result_as_uint64_ptr[2]), "=l"(result_as_uint64_ptr[3]) \
: "l"(global_addr) \
); \
}
// STG.256 (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st)
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_CACHE_HINT_STR should be either "evict_first", "evict_normal", or "evict_last"
#define KU_STG_256(global_addr, src, L1_CACHE_HINT_STR, L2_CACHE_HINT_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(src)> || std::is_array_v<decltype(src)>, "`src` must be a pointer"); \
uint64_t const* src_as_uint64_ptr = (uint64_t const*)(src); \
asm volatile( \
"st.global.L1::" L1_CACHE_HINT_STR ".L2::" L2_CACHE_HINT_STR ".v4.u64 [%0], {%1, %2, %3, %4};\n" \
: \
: "l"(global_addr), "l"(src_as_uint64_ptr[0]), "l"(src_as_uint64_ptr[1]), \
"l"(src_as_uint64_ptr[2]), "l"(src_as_uint64_ptr[3]) \
); \
}
}
namespace kerutils {
// tcgen05.commit.cta_group::1 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_noelect(transac_bar_t &bar) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
:
:"r"(bar_intptr)
);
}
// tcgen05.commit.cta_group::1, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t &bar, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
:
:"r"(bar_intptr), "h"(cta_mask)
);
}
// tcgen05.commit.cta_group::2 (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_2x1SM_noelect(transac_bar_t &bar) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"
:
:"r"(bar_intptr)
);
}
// tcgen05.commit.cta_group::2, with multicast (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen-async-sync-operations-commit)
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &bar, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&bar);
asm volatile(
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
:
:"r"(bar_intptr), "h"(cta_mask)
);
}
// tcgen05.fence::before_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// tcgen05.fence::after_thread_sync (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-special-sync-operations-fence)
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Load from tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumElements>
__device__ __forceinline__
void tmem_ld_32dp32bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
// NOTE The following code crashes VSCode intellisense engine, so we disable it
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_LOAD_32dp32b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_LOAD_32dp32b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_LOAD_32dp32b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 128-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp128bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32 || kNumReplications == 64, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp128b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp128b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp128b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp128b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp128b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp128b32x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 64) {
cute::SM100_TMEM_LOAD_16dp128b64x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*2>{});
#endif
}
// Load from tensor memory, 16 data path lanes, 256-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld)
template <int kNumReplications>
__device__ __forceinline__
void tmem_ld_16dp256bNx(uint32_t tmem_start, void* data_) {
uint32_t* data = (uint32_t*)data_;
static_assert(kNumReplications == 1 || kNumReplications == 2 || kNumReplications == 4 || kNumReplications == 8 || kNumReplications == 16 || kNumReplications == 32, "Invalid kNumReplications");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumReplications == 1) {
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 2) {
cute::SM100_TMEM_LOAD_16dp256b2x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 4) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 8) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 16) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start, data[Is]...);
} else if constexpr (kNumReplications == 32) {
cute::SM100_TMEM_LOAD_16dp256b32x::copy(tmem_start, data[Is]...);
}
}(cute::make_index_sequence<kNumReplications*4>{});
#endif
}
// Store into tensor memory, 32 data path lanes, 32-bit pattern, repeated N times. (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st)
template <int kNumElements>
__device__ __forceinline__
void tmem_st_32dp32bNx(uint32_t tmem_start, void const* data_) {
uint32_t const* data = (uint32_t const*)data_;
static_assert(kNumElements == 1 || kNumElements == 2 || kNumElements == 4 || kNumElements == 8 || kNumElements == 16 || kNumElements == 32 || kNumElements == 64 || kNumElements == 128, "Invalid kNumElements");
#ifndef __VSCODE_IDE__
[&]<size_t... Is>(cute::index_sequence<Is...>) {
if constexpr (kNumElements == 1) {
cute::SM100_TMEM_STORE_32dp32b1x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 2) {
cute::SM100_TMEM_STORE_32dp32b2x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 4) {
cute::SM100_TMEM_STORE_32dp32b4x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 8) {
cute::SM100_TMEM_STORE_32dp32b8x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 16) {
cute::SM100_TMEM_STORE_32dp32b16x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 32) {
cute::SM100_TMEM_STORE_32dp32b32x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 64) {
cute::SM100_TMEM_STORE_32dp32b64x::copy(data[Is]..., tmem_start);
} else if constexpr (kNumElements == 128) {
cute::SM100_TMEM_STORE_32dp32b128x::copy(data[Is]..., tmem_start);
}
}(cute::make_index_sequence<kNumElements>{});
#endif
}
}
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/device/common.h>
namespace cute {
// Extensions to CuTe
// CuTe's built-in SM100_TMA_2SM_LOAD_1D series requires the number of participating threads to be 2 (using ThrID = Layout<_2>;) and also splits the data, which is really annoying to use, so we modified our own version. Additionally, to keep it consistent with other parts that use SM90 TMA, we made it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
[[maybe_unused]] void * smem_ptr,
[[maybe_unused]] int32_t const& crd0)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3}], [%2], %4;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
[[maybe_unused]] void * smem_ptr,
[[maybe_unused]] int32_t const& crd0, int32_t const& crd1)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4}], [%2], %5;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
[[maybe_unused]] void * smem_ptr,
[[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5}], [%2], %6;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
// Executed by both CTAs. Set peer bit to 0 so that the
// transaction bytes will update CTA0's barrier.
uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
asm volatile (
"cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
: "memory");
#else
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
}
};
struct SM100_TMA_2SM_LOAD_NOSPLIT
{
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0)
{
return SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
return SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
return SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
{
return SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
return SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
}
using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
};
struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {};
// The non-executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
with(
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Construct an executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
with(
TmaDescriptor const* new_tma_desc,
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Generate the TMA coord tensor
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Don't try to execute a copy with SM100_TMA_2SM_LOAD_NOSPLIT before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM100_TMA_2SM_LOAD_NOSPLIT with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_NOSPLIT arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint64_t // cache hint
> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache)
: opargs_(desc, mbar, cache) {}
};
}
#pragma once
#include "kerutils/device/common.h"
#include "kerutils/device/sm80/intrinsics.cuh"
namespace kerutils {
// Retrieve the value of `%smid` and check its range
CUTE_DEVICE
uint32_t get_sm_id_with_range_check(uint32_t num_physical_sms) {
uint32_t sm_id = get_sm_id();
if (!(sm_id < num_physical_sms)) {
trap();
}
return sm_id;
}
#ifndef KU_TRAP_ONLY_DEVICE_ASSERT
#define KU_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// Construct a `float2` from a single `float` by duplicating the value
CUTE_DEVICE
float2 float2float2(const float &x) {
return float2 {x, x};
}
CUTE_DEVICE
void st_shared(void* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
CUTE_DEVICE
void st_shared(void* ptr, float4 val) {
st_shared(ptr, *(__int128_t*)&val);
}
CUTE_DEVICE
__int128_t ld_shared(void* ptr) {
__int128_t val;
asm volatile("ld.shared.b128 %0, [%1];" : "=q"(val) : "l"(__cvta_generic_to_shared(ptr)));
return val;
}
CUTE_DEVICE
float4 ld_shared_float4(void* ptr) {
__int128_t temp = ld_shared(ptr);
return *(float4*)&temp;
}
}
#pragma once
#include "kerutils/device/common.h"
namespace kerutils {
// cp.async.cg (cache global) with prefetch and predicate (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async)
template<PrefetchSize PREFETCH_SIZE=PrefetchSize::B128>
CUTE_DEVICE
void cp_async_cacheglobal(const void* src, void* dst, bool pred=true) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
if constexpr (PREFETCH_SIZE == PrefetchSize::B64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else if constexpr (PREFETCH_SIZE == PrefetchSize::B128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else if constexpr (PREFETCH_SIZE == PrefetchSize::B256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16, %2;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0));
} else {
static_assert(PREFETCH_SIZE == PrefetchSize::B64 ||
PREFETCH_SIZE == PrefetchSize::B128 ||
PREFETCH_SIZE == PrefetchSize::B256,
"Unsupported prefetch size for cp_async_cacheglobal.");
}
}
// Create fraction-based cache policy (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-createpolicy)
template<CacheHint PRIMARY_PRIORITY, CacheHint SECONDARY_PRIORITY>
CUTE_DEVICE
int64_t create_fraction_based_cache_policy(float fraction = 1.0f) {
int64_t result;
#define EMIT(PRIMARY_PRIORITY_STR, SECONDARY_PRIORITY_STR) \
asm volatile( \
"createpolicy.fractional.L2::" PRIMARY_PRIORITY_STR ".L2::" SECONDARY_PRIORITY_STR ".b64 %0, %1;\n" \
: "=l"(result) \
: "f"(fraction) \
);
#define EMIT2(PRIMARY_PRIORITY_STR) \
{ \
if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_FIRST) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_first") \
} else if constexpr (SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED) { \
EMIT(PRIMARY_PRIORITY_STR, "evict_unchanged") \
} else { \
static_assert(SECONDARY_PRIORITY == CacheHint::EVICT_FIRST || \
SECONDARY_PRIORITY == CacheHint::EVICT_UNCHANGED, \
"Unsupported secondary cache hint for create_fraction_based_cache_policy."); \
} \
}
if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_FIRST) {
EMIT2("evict_first");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL) {
EMIT2("evict_normal");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_LAST) {
EMIT2("evict_last");
} else if constexpr (PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED) {
EMIT2("evict_unchanged");
} else {
static_assert(PRIMARY_PRIORITY == CacheHint::EVICT_FIRST ||
PRIMARY_PRIORITY == CacheHint::EVICT_NORMAL ||
PRIMARY_PRIORITY == CacheHint::EVICT_LAST ||
PRIMARY_PRIORITY == CacheHint::EVICT_UNCHANGED,
"Unsupported primary cache hint for create_fraction_based_cache_policy.");
}
#undef EMIT
#undef EMIT2
return result;
}
// Create a simple cache policy (equivalent to create_fraction_based_cache_policy(1.0f))
// The same as cute::TMA::CacheHintSmXX
template<CacheHint CACHE_HINT>
CUTE_DEVICE
constexpr int64_t create_simple_cache_policy() {
if constexpr (CACHE_HINT == CacheHint::EVICT_FIRST) {
return 0x12F0000000000000; // Result of createpolicy.fractional.L2::evict_first.b64
} else if constexpr (CACHE_HINT == CacheHint::EVICT_NORMAL) {
return 0x1000000000000000; // Copied from CuTe. Unsure about the exact meaning. (TODO Change to 0x16F0000000000000?)
} else if constexpr (CACHE_HINT == CacheHint::EVICT_LAST) {
return 0x14F0000000000000; // Result of createpolicy.fractional.L2::evict_last.b64
} else {
static_assert(CACHE_HINT == CacheHint::EVICT_FIRST ||
CACHE_HINT == CacheHint::EVICT_NORMAL ||
CACHE_HINT == CacheHint::EVICT_LAST,
"Unsupported cache hint for create_simple_cache_policy.");
}
}
// AtomicAdd (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
CUTE_DEVICE
void atomicadd_f32_with_policy_and_pred(void* global_addr, const float &data, int64_t cache_policy, uint32_t pred = true) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %3, 1;\n\t"
"@p red.relaxed.gpu.global.add.L2::cache_hint.f32 [%1], %0, %2; \n\t"
"}"
:
: "f"(data),
"l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
);
}
// Get the id of the current SM
// About %smid (https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid): PTX document says that %smid ranges from 0 to %nsmid-1, while "The SM identifier numbering is not guaranteed to be contiguous, so %nsmid may be larger than the physical number of SMs in the device.". However, result shows that, at least for sm90 and sm100f, %nsmid is the number of physical SMs - 1. For the sake of safety, I recommend you to check the return of get_sm_id manually or call `get_sm_id_with_range_check()` defined in `device/sm80/helpers.cuh`.
// Besides, PTX document also says that this number may change due to preemption, but currently this never happens according to [DATEN GELÖSCHT]
CUTE_DEVICE
uint32_t get_sm_id() {
uint32_t ret;
asm volatile("mov.u32 %0, %%smid;\n" : "=r"(ret));
return ret;
}
// trap (https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-trap)
CUTE_DEVICE
void trap() {
asm volatile("trap;\n");
}
// LDG.128 or LDG.128 with non-coherent cache (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld)
// We use macro instead of function here, since we need a multi-level recursive dispatch based on template parameters if using function
// NC_STR should be either "" or ".nc"
// L1_CACHE_HINT_STR should be either "evict_first", "evict_normal", "evict_last", "evict_unchanged", or "no_allocate"
// L2_PREFETCH_SIZE_STR should be either "64B", "128B", or "256B"
// L2 cache hint is not supported since it's only supported for LDG.256
#define KU_LDG_128(global_addr, result, NC_STR, L1_CACHE_HINT_STR, L2_PREFETCH_SIZE_STR) \
{ \
static_assert(std::is_pointer_v<decltype(global_addr)> || std::is_array_v<decltype(global_addr)>, "`global_addr` must be a pointer"); \
static_assert(std::is_pointer_v<decltype(result)> || std::is_array_v<decltype(result)>, "`result` must be a pointer"); \
uint64_t* result_as_uint64_ptr = (uint64_t*)(result); \
asm volatile( \
"ld.global" NC_STR ".L1::" L1_CACHE_HINT_STR ".L2::" L2_PREFETCH_SIZE_STR ".v2.u64 {%0, %1}, [%2];\n" \
: "=l"(result_as_uint64_ptr[0]), "=l"(result_as_uint64_ptr[1]) \
: "l"(global_addr) \
); \
}
}
...@@ -6,105 +6,105 @@ ...@@ -6,105 +6,105 @@
namespace kerutils { namespace kerutils {
template< // template<
typename TMA, // typename TMA,
typename Tensor0, // typename Tensor0,
typename Tensor1 // typename Tensor1
> // >
CUTE_DEVICE // CUTE_DEVICE
void launch_tma_copy( // void launch_tma_copy(
const TMA &tma_copy, // const TMA &tma_copy,
Tensor0 src, // Tensor0 src,
Tensor1 dst, // Tensor1 dst,
transac_bar_t &bar, // transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL // const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) { // ) {
auto thr_tma = tma_copy.get_slice(cute::_0{}); // auto thr_tma = tma_copy.get_slice(cute::_0{});
cute::copy( // cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint), // tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src), // thr_tma.partition_S(src),
thr_tma.partition_D(dst) // thr_tma.partition_D(dst)
); // );
} // }
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a // // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE // CUTE_DEVICE
int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { // int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); // int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx; // return row_idx;
} // }
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx // // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in some rows. This function converts the local_elem_idx to the actual col_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a // // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
CUTE_DEVICE // CUTE_DEVICE
int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { // int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); // int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
return col_idx; // return col_idx;
} // }
template <bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE // CUTE_DEVICE
void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) { // void wgmma(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC, bool zero_init) {
using namespace cute; // using namespace cute;
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value; // constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const // // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
warpgroup_arrive(); // warpgroup_arrive();
tiled_mma.accumulate_ = zero_init ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // tiled_mma.accumulate_ = zero_init ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
// Unroll the K mode manually to set scale D to 1 // // Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
if constexpr (commit) { // if constexpr (commit) {
warpgroup_commit_batch(); // warpgroup_commit_batch();
} // }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
} // }
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE // CUTE_DEVICE
void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { // void wgmma_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute; // using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); // ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sA_frag = thr_mma.partition_fragment_A(sA); // Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB); // Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag)); // static_assert(size<2>(sA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_arrive(); // warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) { // for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); // cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
} // }
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
CUTE_DEVICE // CUTE_DEVICE
void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { // void wgmma_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute; // using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); // ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sB_frag = thr_mma.partition_fragment_B(sB); // Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(rA_frag) == size<2>(sB_frag)); // static_assert(size<2>(rA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag)); // warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_arrive(); // warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(rA_frag); ++k) { // for (int k = 0; k < size<2>(rA_frag); ++k) {
cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); // cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag)); // warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
} // }
} }
...@@ -4,104 +4,104 @@ ...@@ -4,104 +4,104 @@
namespace kerutils { namespace kerutils {
// st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async) // // st.async (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-async)
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) { // static void st_async(void* dst_ptr, const T& data, transac_bar_t &mbar) {
static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async."); // static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async.");
long2 data_long2 = *reinterpret_cast<const long2*>(&data); // long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar); // uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile ( // asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" // "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
: // :
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) // : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
); // );
} // }
static constexpr int PEER_ADDR_MASK = 16777216; // static constexpr int PEER_ADDR_MASK = 16777216;
// Given an address in the current CTA, return the corresponding address in the peer CTA // // Given an address in the current CTA, return the corresponding address in the peer CTA
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
T* get_peer_addr(const T* p) { // T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); // return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
} // }
// Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0) // // Given an address in the current CTA, return the corresponding address in the peer CTA (if the current CTA_id%2 == 1) or the address itself (if CTA_id%2 == 0)
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
T* get_cta0_addr(const T* p) { // T* get_cta0_addr(const T* p) {
constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF; // constexpr int CTA0_ADDR_MASK = 0xFEFFFFFF;
return (T*)((int64_t)(p) & CTA0_ADDR_MASK); // return (T*)((int64_t)(p) & CTA0_ADDR_MASK);
} // }
// TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk) // // TMA bulk reduce add (cp.reduce.async.bulk), shared to global, float32, add. (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk)
CUTE_DEVICE // CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) { // void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr); // uint32_t smem_int_ptr = cute::cast_smem_ptr_to_uint(src_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n" // asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
: // :
: "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes) // : "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory"); // : "memory");
} // }
// Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) // // Cluster barrier arrive with .release modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE // CUTE_DEVICE
void barrier_cluster_arrive_release() { // void barrier_cluster_arrive_release() {
asm volatile("barrier.cluster.arrive.release;" : : : "memory"); // asm volatile("barrier.cluster.arrive.release;" : : : "memory");
} // }
// Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) // // Cluster barrier arrive with .relaxed modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE // CUTE_DEVICE
void barrier_cluster_arrive_relaxed() { // void barrier_cluster_arrive_relaxed() {
asm volatile("barrier.cluster.arrive.relaxed;" : : :); // asm volatile("barrier.cluster.arrive.relaxed;" : : :);
} // }
// Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster) // // Cluster barrier wait with .acquire modifier. (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-barrier-cluster)
CUTE_DEVICE // CUTE_DEVICE
void barrier_cluster_wait_acquire() { // void barrier_cluster_wait_acquire() {
asm volatile("barrier.cluster.wait.acquire;" : : : "memory"); // asm volatile("barrier.cluster.wait.acquire;" : : : "memory");
} // }
// mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive) // // mbarrier.arrive with .relaxed.cluster qualifier (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
CUTE_DEVICE // CUTE_DEVICE
void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) { // void mbarrier_arrive_relaxed_cluster(transac_bar_t &mbar) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar); // uint32_t smem_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile( // asm volatile(
"{\n\t" // "{\n\t"
"mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t" // "mbarrier.arrive.relaxed.cluster.shared::cta.b64 _, [%0];\n\t"
"}" // "}"
: // :
: "r"(smem_addr)); // : "r"(smem_addr));
} // }
// AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red) // // AtomicAdd with v4.f32 type (https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-red)
CUTE_DEVICE // CUTE_DEVICE
void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) { // void atomicadd_f32x4_with_policy_and_pred(void* global_addr, const float4 &data, int64_t cache_policy, uint32_t pred = true) {
asm volatile( // asm volatile(
"{\n\t" // "{\n\t"
".reg .pred p;\n\t" // ".reg .pred p;\n\t"
"setp.eq.u32 p, %6, 1;\n\t" // "setp.eq.u32 p, %6, 1;\n\t"
"@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t" // "@p red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
"}" // "}"
: // :
: "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w), // : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
"l"((int64_t)global_addr), "l"(cache_policy), "r"(pred) // "l"((int64_t)global_addr), "l"(cache_policy), "r"(pred)
); // );
} // }
// cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk) // // cp.async.bulk, from .shared::cta to .shared::cluster (https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk)
CUTE_DEVICE // CUTE_DEVICE
void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) { // void cp_async_bulk_shared_cta_to_shared_cluster(void* dst_ptr, const void* src_ptr, int32_t load_bytes, transac_bar_t &mbar) {
uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr); // uint32_t dst_smem_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr); // uint32_t src_smem_addr = cute::cast_smem_ptr_to_uint(src_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar); // uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(&mbar);
asm volatile( // asm volatile(
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n" // "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
: // :
: "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr) // : "r"(dst_smem_addr), "r"(src_smem_addr), "r"(load_bytes), "r"(mbar_addr)
); // );
} // }
} }
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda.h> #include <cuda.h>
#include <cutlass/cuda_host_adapter.hpp> // #include <cutlass/cuda_host_adapter.hpp>
#include "kerutils/common/common.h" #include "kerutils/common/common.h"
...@@ -78,78 +78,78 @@ inline __host__ __device__ constexpr T ceil(const T &a, const T &b) { ...@@ -78,78 +78,78 @@ inline __host__ __device__ constexpr T ceil(const T &a, const T &b) {
return (a + b - 1) / b * b; return (a + b - 1) / b * b;
} }
// A wrapper for make_tensor_map // // A wrapper for make_tensor_map
static inline CUtensorMap make_tensor_map( // static inline CUtensorMap make_tensor_map(
const std::vector<uint64_t> &size, // const std::vector<uint64_t> &size,
const std::vector<uint64_t> &strides, // PAY ATTENTION: In BYTES // const std::vector<uint64_t> &strides, // PAY ATTENTION: In BYTES
const std::vector<uint32_t> &box_size, // const std::vector<uint32_t> &box_size,
void* global_ptr, // void* global_ptr,
CUtensorMapDataType data_type, // CUtensorMapDataType data_type,
CUtensorMapSwizzle swizzle_mode, // CUtensorMapSwizzle swizzle_mode,
CUtensorMapL2promotion l2_promotion, // CUtensorMapL2promotion l2_promotion,
CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, // CUtensorMapInterleave interleave_mode = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE, // CUtensorMapFloatOOBfill oob_fill = CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
const std::vector<uint32_t> &element_strides_ = {} // const std::vector<uint32_t> &element_strides_ = {}
) { // ) {
int dim = size.size(); // int dim = size.size();
KU_ASSERT(dim >= 1); // KU_ASSERT(dim >= 1);
std::vector<uint32_t> element_strides; // std::vector<uint32_t> element_strides;
if (element_strides_.empty()) { // if (element_strides_.empty()) {
for (int i = 0; i < dim; ++i) // for (int i = 0; i < dim; ++i)
element_strides.push_back(1); // element_strides.push_back(1);
} else { // } else {
element_strides = element_strides_; // element_strides = element_strides_;
} // }
KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim); // KU_ASSERT(strides.size() == (uint32_t)dim-1 && box_size.size() == (uint32_t)dim && element_strides.size() == (uint32_t)dim);
CUtensorMap result; // CUtensorMap result;
CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( // CUresult ret_code = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&result, // &result,
data_type, // data_type,
dim, // dim,
global_ptr, // global_ptr,
size.data(), // size.data(),
strides.data(), // strides.data(),
box_size.data(), // box_size.data(),
element_strides.data(), // element_strides.data(),
interleave_mode, // interleave_mode,
swizzle_mode, // swizzle_mode,
l2_promotion, // l2_promotion,
oob_fill // oob_fill
); // );
if (ret_code != CUresult::CUDA_SUCCESS) { // if (ret_code != CUresult::CUDA_SUCCESS) {
auto print_vector = [&](auto t, const char* fmt, const char end='\n') { // auto print_vector = [&](auto t, const char* fmt, const char end='\n') {
for (auto elem : t) { // for (auto elem : t) {
printf(fmt, elem); // printf(fmt, elem);
} // }
printf("%c", end); // printf("%c", end);
}; // };
fprintf(stderr, "Failed to create tensormap\n"); // fprintf(stderr, "Failed to create tensormap\n");
fprintf(stderr, "Dim: %d\n", dim); // fprintf(stderr, "Dim: %d\n", dim);
printf("size: "); print_vector(size, "%lu "); // printf("size: "); print_vector(size, "%lu ");
printf("strides: "); print_vector(strides, "%lu "); // printf("strides: "); print_vector(strides, "%lu ");
printf("box_size: "); print_vector(box_size, "%u "); // printf("box_size: "); print_vector(box_size, "%u ");
printf("element_strides: "); print_vector(element_strides, "%u "); // printf("element_strides: "); print_vector(element_strides, "%u ");
printf("global ptr: 0x%lx\n", (int64_t)global_ptr); // printf("global ptr: 0x%lx\n", (int64_t)global_ptr);
printf("data_type: %d\n", (int)data_type); // printf("data_type: %d\n", (int)data_type);
printf("swizzle_mode: %d\n", (int)swizzle_mode); // printf("swizzle_mode: %d\n", (int)swizzle_mode);
printf("l2_promotion: %d\n", (int)l2_promotion); // printf("l2_promotion: %d\n", (int)l2_promotion);
printf("interleave_mode: %d\n", (int)interleave_mode); // printf("interleave_mode: %d\n", (int)interleave_mode);
printf("oob_fill: %d\n", (int)oob_fill); // printf("oob_fill: %d\n", (int)oob_fill);
KU_ASSERT(false); // KU_ASSERT(false);
} // }
return result; // return result;
} // }
// Given strides (in number of elements), this function converts their datatype in uint64_t and then multiplies by elem_size // // Given strides (in number of elements), this function converts their datatype in uint64_t and then multiplies by elem_size
template<typename T> // template<typename T>
static inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) { // static inline std::vector<uint64_t> make_stride_helper(const std::vector<T> &strides_in_elems, size_t elem_size) {
std::vector<uint64_t> res; // std::vector<uint64_t> res;
for (auto stride : strides_in_elems) { // for (auto stride : strides_in_elems) {
res.push_back(((uint64_t)stride) * elem_size); // res.push_back(((uint64_t)stride) * elem_size);
} // }
return res; // return res;
} // }
} }
\ No newline at end of file
...@@ -7,7 +7,7 @@ enum class ModelType { ...@@ -7,7 +7,7 @@ enum class ModelType {
MODEL1 MODEL1
}; };
struct __align__(4*8) DecodingSchedMeta { struct alignas(32) DecodingSchedMeta {
int begin_req_idx, end_req_idx; // Both inclusive int begin_req_idx, end_req_idx; // Both inclusive
int begin_block_idx, end_block_idx; // Inclusive, exclusive int begin_block_idx, end_block_idx; // Inclusive, exclusive
int begin_split_idx; int begin_split_idx;
......
Head128 decoding kernels are located at `csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu` (for k_dim = 512) or simulated using 2x head64 kernel
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment