Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
@@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) {
      }
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
            tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);
@@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);
@@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
                          std::string());
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_hip", BuildTileLangHIP)
      .def("target.build.tilelang_hip_without_compile",
           BuildTileLangHIPWithoutCompile);
}

} // namespace codegen
} // namespace tvm
@@ -5,6 +5,9 @@
#include "utils.h"

#include "../support/ffi_aliases.h"
#include <tvm/node/node.h>

namespace tvm {
namespace tl {
@@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) {
}

int GetArchInt(Target target) {
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value());
  const std::string arch_str = s.value();
  ICHECK(arch_str.size() >= 3);
  ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
@@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // if mcpu starts with "gfx9", it is CDNA
    return mcpu.find("gfx9") == 0;
  }
@@ -94,7 +97,7 @@ bool TargetHasAsyncCopy(Target target) {
    return arch >= 80;
  } else if (TargetIsCDNA(target)) {
    if (target->attrs.count("mcpu")) {
      std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
      if (mcpu.rfind("gfx9", 0) == 0) {
        int gfx_version = std::stoi(mcpu.substr(3, 2));
        return gfx_version >= 94;
@@ -141,7 +144,7 @@ int TargetGetWarpSize(Target target) {
  return res;
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.TargetIsCuda",
@@ -170,7 +173,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
           [](Target target) { return TargetHasBulkCopy(target); })
      .def("tl.TargetGetWarpSize",
           [](Target target) { return TargetGetWarpSize(target); });
}

} // namespace tl
} // namespace tvm
@@ -10,6 +10,9 @@
#include <cutlass/numeric_types.h>
#include <math_constants.h>

#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
@@ -285,6 +288,138 @@ union GmmaDescriptor {
  }
};
union Tcgen05SMemDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
// Bitfield implementation avoids the need for shifts in assignment
struct {
// start_address, bit [0,14), 4LSB not included
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// leading dimension byte offset, bit [16,30), 4LSB not included
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// stride dimension byte offset, bit [32,46), 4LSB not included
uint16_t stride_byte_offset_ : 14,
version_ : 2; // 14 bits [0,14), 2 bits [14,16)
// base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1,
: 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
// layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0,
// SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4,
// SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5,
// N/A = 7
uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8)
} bitfield;
// Separate the field, as we may only update one part of desc
struct {
uint32_t lo;
uint32_t hi;
} words;
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor
operator+(const T &offset) const {
Tcgen05SMemDescriptor ret;
// Address addition is in units of 16 bytes (4 LSB not encoded)
ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
};
//
// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout)
//
union Tcgen05InstrDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint32_t desc_;
uint16_t reg16_[2];
// Bitfield implementation mirrors cute::UMMA::InstrDescriptor
struct {
// bit [ 0, 2) : Sparse meta data id2
uint16_t sparse_id2_ : 2,
// bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for
// F32F16/S8/MXF8F6F4
sparse_flag_ : 1,
// bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8
saturate_ : 1,
// bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
c_format_ : 2,
// padding
: 1,
// bit [ 7,10) : see UMMA format encoding
a_format_ : 3,
// bit [10,13) : see UMMA format encoding
b_format_ : 3,
// bit [13,14) : 0 = no negate. 1 = negate
a_negate_ : 1,
// bit [14,15) : 0 = no negate. 1 = negate
b_negate_ : 1,
// bit [15,16) : 0 = K-major. 1 = MN-major
a_major_ : 1;
// Upper 16 bits
uint16_t b_major_ : 1, // bit [16,17)
n_dim_ : 6, // bit [17,23) : 3 LSBs not included
: 1, // padding
m_dim_ : 5, // bit [24,29) : 4 LSBs not included
: 1, // padding
max_shift_ : 2; // bit [30,32)
} bitfield;
// Decay to a uint32_t
CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept {
return desc_;
}
};
// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
  for (int i = 0; i < size; i++) {
@@ -323,8 +458,8 @@ TL_DEVICE void __sync_thread_partial() {
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor,
                                           T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
@@ -333,15 +468,151 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}
template <typename T>
TL_DEVICE void
initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor,
T *start_address, int leading_byte_offset,
int stride_byte_offset, int base_offset,
bool leading_is_absolute, int swizzle_mode) {
descriptor.bitfield.start_address_ =
static_cast<uint16_t>(cast_smem_ptr_to_uint(start_address) >> 4);
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
descriptor.bitfield.version_ = 1;
descriptor.bitfield.base_offset_ = base_offset & 0x7;
descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0;
descriptor.bitfield.layout_type_ = swizzle_mode & 0x7;
}
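
// Usage sketch (illustrative only, not part of this header's interface): how
// the helper above might be called for a 128B-swizzled shared-memory tile.
// The function name and the offset/swizzle constants are hypothetical example
// values; offsets are in 16-byte units (the 4 LSBs are not encoded).
template <typename T>
TL_DEVICE Tcgen05SMemDescriptor make_example_tcgen05_descriptor(T *smem_tile) {
  Tcgen05SMemDescriptor desc{};
  initialize_tcgen05_descriptor(desc, smem_tile,
                                /*leading_byte_offset=*/1,
                                /*stride_byte_offset=*/4,
                                /*base_offset=*/0,
                                /*leading_is_absolute=*/false,
                                /*swizzle_mode=*/2); // SWIZZLE_128B encoding
  return desc;
}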
template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                          T offset) {
  descriptor.reg32_[0] += (offset >> 4);
}
// Wrap cute's FP8 types and add the desired conversion from bfloat16_t.
struct float_e4m3_t : public cute::float_e4m3_t {
using cute::float_e4m3_t::float_e4m3_t;
CUTLASS_HOST_DEVICE
float_e4m3_t() = default;
CUTLASS_HOST_DEVICE
explicit float_e4m3_t(__nv_bfloat16 x)
: float_e4m3_t(static_cast<float>(x)) {}
};
struct float_e5m2_t : public cute::float_e5m2_t {
using cute::float_e5m2_t::float_e5m2_t;
CUTLASS_HOST_DEVICE
float_e5m2_t() = default;
CUTLASS_HOST_DEVICE
explicit float_e5m2_t(__nv_bfloat16 x)
: float_e5m2_t(static_cast<float>(x)) {}
};
template <typename T> struct to_cute_type {
using type = T;
};
template <> struct to_cute_type<tl::float_e4m3_t> {
using type = cute::float_e4m3_t;
};
template <> struct to_cute_type<tl::float_e5m2_t> {
using type = cute::float_e5m2_t;
};
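
// Illustrative compile-time checks (example only): to_cute_type maps the tl::
// FP8 wrappers back to their cute base types and leaves every other type
// untouched, so cute MMA atoms can keep consuming the underlying types.
static_assert(std::is_same<to_cute_type<float_e4m3_t>::type,
                           cute::float_e4m3_t>::value,
              "tl::float_e4m3_t lowers to cute::float_e4m3_t");
static_assert(std::is_same<to_cute_type<half_t>::type, half_t>::value,
              "non-FP8 types pass through unchanged");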
} // namespace tl

namespace cutlass {
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass
//
// Type-safe warp shuffle helpers for 16-bit float types
// These wrappers avoid relying on implicit conversions that may be disallowed
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
// float for the shuffle and then down-converting.
//
namespace tl {
// Generic passthroughs
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
return __shfl_xor_sync(mask, val, laneMask);
}
template <typename T>
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
return __shfl_down_sync(mask, val, delta);
}
template <typename T>
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
return __shfl_up_sync(mask, val, delta);
}
template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
return __shfl_sync(mask, val, srcLane);
}
// Specializations for cutlass::half_t
template <>
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
float f = static_cast<float>(val);
float r = __shfl_xor_sync(mask, f, laneMask);
return half_t(r);
}
template <>
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down_sync(mask, f, delta);
return half_t(r);
}
template <>
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up_sync(mask, f, delta);
return half_t(r);
}
template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl_sync(mask, f, srcLane);
return half_t(r);
}
// Specializations for cutlass::bfloat16_t
template <>
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
int laneMask) {
float f = static_cast<float>(val);
float r = __shfl_xor_sync(mask, f, laneMask);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down_sync(mask, f, delta);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up_sync(mask, f, delta);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl_sync(mask, f, srcLane);
return bfloat16_t(r);
}
} // namespace tl
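
// Usage sketch (illustrative only): a full-warp butterfly all-reduce built on
// the typed shuffle wrappers above. The helper name and the full-warp mask are
// example choices; the accumulation is done in float and converted back.
namespace tl {
TL_DEVICE bfloat16_t warp_allreduce_sum_bf16_example(bfloat16_t val) {
  float acc = static_cast<float>(val);
  for (int offset = 16; offset > 0; offset >>= 1) {
    // shfl_xor_sync<bfloat16_t> promotes to float for the shuffle internally.
    acc += static_cast<float>(
        shfl_xor_sync(0xffffffffu, bfloat16_t(acc), offset));
  }
  return bfloat16_t(acc);
}
} // namespace tl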
#pragma once

#include "common.h"
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>

using fp8_e4_t = tl::float_e4m3_t;
using fp8_e5_t = tl::float_e5m2_t;

struct __CUDA_ALIGN__(2) fp8_e4_2_t {
  fp8_e4_t x;
...
@@ -263,16 +263,18 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type =
      typename std::conditional<std::is_same<A_type_cute, float>::value,
                                tfloat32_t, A_type_cute>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_cute, float>::value,
                                tfloat32_t, B_type_cute>::type;
  using C_type = C_type_raw;

  using Instruction = DispatchInstruction<A_type_raw, B_type_raw, C_type_raw,
                                          num_warp_m, num_warp_n, N>;
  using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
                                       !trans_A, num_warp_m, lda>;
...
@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type =
      typename std::conditional<std::is_same<A_type_cute, float>::value,
                                tfloat32_t, A_type_cute>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_cute, float>::value,
                                tfloat32_t, B_type_cute>::type;
  using C_type = C_type_raw;

  static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
...
@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
                               tfloat32_t, A_type_cute>;
  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
                               tfloat32_t, B_type_cute>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
...
@@ -13,10 +13,12 @@ class GemmTensorOp {
public:
  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
                               tfloat32_t, A_type_cute>;
  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
                               tfloat32_t, B_type_cute>;
  using C_type = C_type_raw;

  static constexpr bool need_tfloat32_cast =
...
#pragma once
#include "../common.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
template <class Impl> struct MmaImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c) {
call_fma_impl<Impl>(d, a, b, c,
std::make_index_sequence<MmaImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kCRegs>{});
}
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate>
struct MmaDispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync: unsupported configuration");
}
};
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
NValue, KValue, TransAValue, TransBValue, \
SaturateValue, ImplType) \
template <> \
struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, MValue, NValue, KValue, \
TransAValue, TransBValue, SaturateValue> { \
using Impl = ImplType; \
using Traits = MmaImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma<Impl>(d, a, b, c); \
} \
};
// FP16 inputs (TN layout: A row-major, B column-major)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F16F16F16F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32F16F16F32_TN)
// BF16 inputs
TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32BF16BF16F32_TN)
// INT8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S8S8S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U8U8S32_TN)
// INT4 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S4S4S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U4U4S32_TN)
// FP8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)
// TF32 inputs (FP32 math on Tensor Cores)
// Support both k=4 and k=8 variants on SM80
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4,
false, true, false,
cute::SM80_16x8x4_F32TF32TF32F32_TN)
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8,
false, true, false,
cute::SM80_16x8x8_F32TF32TF32F32_TN)
// FP64 inputs (DMMA: m8n8k4, TN layout)
TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true,
false, cute::SM80_8x8x4_F64F64F64F64_TN)
#undef TL_DEFINE_MMA_DISPATCHER
} // namespace detail
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate = false>
TL_DEVICE void mma_sync(
typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA, TransB,
Saturate>::CRegType *c,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::ARegType *a,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::BRegType *b) {
using Dispatcher = detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync: unsupported configuration");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
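
// Usage sketch (illustrative only): issuing a single m16n8k16 F16xF16->F32 MMA
// through the dispatcher above. The wrapper name is hypothetical; the register
// counts follow cute::SM80_16x8x16_F32F16F16F32_TN (4 A regs, 2 B regs,
// 4 C/D regs per thread), and operand packing is the caller's responsibility.
namespace tl {
TL_DEVICE void mma_sync_f16f32_example(float c[4], const unsigned a[4],
                                       const unsigned b[2]) {
  mma_sync<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
           /*M=*/16, /*N=*/8, /*K=*/16, /*TransA=*/false, /*TransB=*/true>(
      c, a, b);
}
} // namespace tl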
#pragma once
#include "../common.h"
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
// SM70 MMA Instruction Traits and Implementations
// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32
// accumulation
// Base template for SM70 MMA implementation
template <DataType AType, DataType BType, DataType CType, bool TransA,
bool TransB>
struct MmaSm70Impl {
// Default: unsupported configuration
static constexpr bool kSupported = false;
static TL_DEVICE void exec(void *, const void *, const void *, const void *) {
static_assert(always_false_v<std::integral_constant<bool, TransA>>,
"tl::mma_sync_sm70: unsupported configuration");
}
};
// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// Helper to extract register types
template <class Impl> struct MmaSm70ImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
// Dispatcher for SM70 MMA operations
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
struct MmaSm70Dispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 "
"accumulation.");
}
};
// Helper to call fma with unpacked register arrays
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c) {
call_fma_impl_sm70<Impl>(
d, a, b, c, std::make_index_sequence<MmaSm70ImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kCRegs>{});
}
// Define dispatchers for all supported SM70 configurations
// Note: m8n8k4 instruction computes m16n16k4 at warp level
#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, \
TransAValue, TransBValue) \
template <> \
struct MmaSm70Dispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, 16, 16, 4, TransAValue, \
TransBValue> { \
using Impl = MmaSm70Impl<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, TransAValue, TransBValue>; \
using Traits = MmaSm70ImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync_sm70 requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma_sm70<Impl>(d, a, b, c); \
} \
};
// FP16 inputs with FP16 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false)
// FP16 inputs with FP32 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false)
#undef TL_DEFINE_MMA_SM70_DISPATCHER
} // namespace detail
/// SM70 MMA synchronous instruction wrapper
/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs
/// and FP16/FP32 accumulation
///
/// @tparam AType Input A data type (kFloat16)
/// @tparam BType Input B data type (kFloat16)
/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32)
/// @tparam M Matrix M dimension (16)
/// @tparam N Matrix N dimension (16)
/// @tparam K Matrix K dimension (4)
/// @tparam TransA Whether A is transposed (false=row-major, true=col-major)
/// @tparam TransB Whether B is transposed (false=row-major, true=col-major)
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
TL_DEVICE void mma_sync_sm70(
typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA,
TransB>::CRegType *c,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::ARegType *a,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::BRegType *b) {
using Dispatcher =
detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA, TransB>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs.");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
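
// Usage sketch (illustrative only): one warp-level m16n16k4 SM70 MMA with FP16
// inputs and FP32 accumulation in row.col form. The wrapper name is
// hypothetical; register counts follow the MmaSm70Impl definition above
// (2 A regs, 2 B regs, 8 C/D regs per thread).
namespace tl {
TL_DEVICE void mma_sync_sm70_f32_example(float c[8], const unsigned a[2],
                                         const unsigned b[2]) {
  mma_sync_sm70<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
                /*M=*/16, /*N=*/16, /*K=*/4, /*TransA=*/false,
                /*TransB=*/true>(c, a, b);
}
} // namespace tl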
#pragma once
#include "../common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ss: unsupported accumulator type");
}
// TS variants: A from TMEM, B from SMEM (desc)
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ts(uint32_t const & /*tmem_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ts: unsupported accumulator type");
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kBFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat16>(tmem_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kTensorFloat32>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kInt8>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e4m3>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, "
"{%5, %6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e5m2>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat8_e4m3>(tmem_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
// idescE upper 32 bits carry the instruction descriptor; lower 32 ignored for
// SS Load TMEM base from shared memory slot handled by caller
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx
// Generic declaration falls back to static assert
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ws_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ws_ss: unsupported accumulator type");
}
// F16/BF16 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// INT8 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// FP8 ws (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3);
}
} // namespace tl
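
// Usage sketch (illustrative only): launching one SS-form tcgen05 MMA with the
// f16 accumulator kind. desc_a/desc_b are 64-bit SMEM descriptors (e.g. built
// from Tcgen05SMemDescriptor), tmem_c is a TMEM address, and idesc carries the
// packed instruction-descriptor bits; the zero masks and the wrapper name are
// hypothetical example choices.
namespace tl {
TL_DEVICE void tcgen05mma_ss_f16_example(uint64_t desc_a, uint64_t desc_b,
                                         uint32_t tmem_c, uint32_t idesc,
                                         bool accumulate) {
  tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c,
                                    /*scalec=*/accumulate ? 1u : 0u,
                                    /*desc_val=*/idesc, /*mask0=*/0,
                                    /*mask1=*/0, /*mask2=*/0, /*mask3=*/0);
}
} // namespace tl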
#pragma once #pragma once
#include "../common.h" #include "../common.h"
#include "cute/arch/mma_sm90_gmma.hpp" #include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <type_traits>
#include <utility>
namespace tl { namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false; template <class> inline constexpr bool always_false_v = false;
#endif
// 主类模板 - 移除默认参数,因为特化不能有默认参数 namespace detail {
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, "
"C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, "
"scaleB=%d\n",
(int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N,
K, (int)tnspA, (int)tnspB, scaleA, scaleB);
// 暂时注释掉 static_assert 来看调试输出
// static_assert(always_false_v<decltype(c)>,
// "wgmma_ss: No specialization available for given template
// parameters!");
};
};
// ================================= F16 x F16 -> F16
// =================================
// M64N8K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N32K16 F16 template <bool IsMnMajor> struct MajorValue {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static constexpr auto value =
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K;
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
}; };
// M64N64K16 F16 template <int Scale> struct ScaleInValue {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static_assert(Scale == 1 || Scale == -1,
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, "tl::wgmma requires scale factors of +1 or -1.");
64, 64, 16, tnspA, tnspB, scaleA, scaleB> { static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, : cute::SM90::GMMA::ScaleIn::Neg;
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15},"
" %16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
}; };
// M64N96K16 F16 template <int Scale>
template <bool tnspA, bool tnspB, int scaleA, int scaleB> inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1);
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 96, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23}, "
"%24, %25, p, %27, %28, %29, %30;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
template <class Impl> struct CallWgmmaSS {
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
static_assert(sizeof(CReg) == sizeof(uint32_t),
"tl::wgmma_ss expects 32-bit accumulator registers.");

template <size_t... Idx>
TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c,
cute::SM90::GMMA::ScaleOut scale,
std::index_sequence<Idx...>) {
Impl::fma(desc_a, desc_b, c[Idx]..., scale);
}

TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw,
bool scale_out) {
auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
: cute::SM90::GMMA::ScaleOut::Zero;
auto c = reinterpret_cast<CReg *>(c_raw);
Run(desc_a, desc_b, c, scale, std::make_index_sequence<kCRegs>{});
}
};

// M64N128K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 128, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N192K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 192, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47}, "
"%48, %49, p, %51, %52, %53, %54;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
// M64N256K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 256, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47, "
"%48, %49, %50, %51, %52, %53, %54, %55, "
"%56, %57, %58, %59, %60, %61, %62, %63}, "
"%64, %65, p, %67, %68, %69, %70;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47]), "+r"(c[48]), "+r"(c[49]),
"+r"(c[50]), "+r"(c[51]), "+r"(c[52]), "+r"(c[53]), "+r"(c[54]),
"+r"(c[55]), "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]),
"+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
template <class Impl> struct CallWgmmaRS {
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
static_assert(sizeof(AReg) == sizeof(uint32_t),
"tl::wgmma_rs expects 32-bit register operands for A.");
static_assert(sizeof(CReg) == sizeof(uint32_t) ||
sizeof(CReg) == sizeof(float),
"tl::wgmma_rs expects 32-bit accumulator registers.");

template <size_t... AIdx, size_t... CIdx>
TL_DEVICE static void
Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale,
std::index_sequence<AIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
}

TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b,
uint32_t *c_raw, bool scale_out) {
auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
: cute::SM90::GMMA::ScaleOut::Zero;
auto a = reinterpret_cast<const AReg *>(a_raw);
auto c = reinterpret_cast<CReg *>(c_raw);
Run(a, desc_b, c, scale, std::make_index_sequence<kARegs>{},
std::make_index_sequence<kCRegs>{});
}
};
} // namespace detail

// ================================= F16 x F16 -> F32
// =================================

// M64N8K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n" template <size_t... AIdx, size_t... CIdx>
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " TL_DEVICE static void
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale,
"}\n" std::index_sequence<AIdx...>, std::index_sequence<CIdx...>) {
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
// M64N32K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15}, "
"%16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::wgmma_ss: unsupported configuration");
}
};

// M64N64K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 64, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaRSImpl {
static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_rs: invalid scaleA");
static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_rs: invalid scaleB");
TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::wgmma_rs: unsupported configuration");
}
};

// ================================= BF16 x BF16 -> F32
// =================================

// M64N8K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, tnspA, tnspB, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
detail::MajorValue<tnspB>::value, \
detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};

// M64N16K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};

// ================================= TF32 x TF32 -> F32
// =================================

// M64N8K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 8, 8, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
static_assert(scaleA == 1 && scaleB == 1, \
"tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \
using Impl = cute::SM90::GMMA::ImplName; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};

// M64N16K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 16, 8, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, tnspA, tnspB, scaleA, scaleB> { \
static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
detail::MajorValue<tnspB>::value, \
detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};

// ================================= INT8 x INT8 -> INT32
// =================================

// M64N8K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};

// M64N16K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 16,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
static_assert(scaleA == 1 && scaleB == 1, \
"tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \
using Impl = cute::SM90::GMMA::ImplName; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};

// ================================= FP8 x FP8 -> F16/F32
// =================================

// M64N8K32 E4M3->F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N8K32 E4M3->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat32, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};

#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(40) \
OP(48) \
OP(56) \
OP(64) \
OP(72) \
OP(80) \
OP(88) \
OP(96) \
OP(104) \
OP(112) \
OP(120) \
OP(128) \
OP(136) \
OP(144) \
OP(152) \
OP(160) \
OP(168) \
OP(176) \
OP(184) \
OP(192) \
OP(200) \
OP(208) \
OP(216) \
OP(224) \
OP(232) \
OP(240) \
OP(248) \
OP(256)
#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(48) \
OP(64) \
OP(80) \
OP(96) \
OP(112) \
OP(128) \
OP(144) \
OP(160) \
OP(176) \
OP(192) \
OP(208) \
OP(224) \
OP(240) \
OP(256)
#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_SS)
#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_SS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_SS)
#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_SS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN);
#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_RS)
#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_RS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_RS)
#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_RS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN);
#undef TL_WGMMA_DEFINE_F16_F16_F16_SS
#undef TL_WGMMA_DEFINE_F16_F16_F32_SS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS
#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_F16_F16_RS
#undef TL_WGMMA_DEFINE_F16_F16_F32_RS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS
#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN
#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8
#undef TL_WGMMA_FOREACH_N_INT32_MUL8
#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_SS_GENERAL
#undef TL_WGMMA_DEFINE_SS_TN
#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_RS_GENERAL
#undef TL_WGMMA_DEFINE_RS_TN
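For reference, each DEFINE/FOREACH pair above simply stamps out one explicit specialization of WgmmaSSImpl or WgmmaRSImpl per tile width N. As a rough, hand-expanded illustration (not part of the source, derived from the macro bodies above), TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, 8, 16, MMA_64x8x16_F16F16F16_SS) produces approximately:

// Illustrative expansion of one TL_WGMMA_DEFINE_SS_GENERAL instantiation.
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
                   64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
  static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
  static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
  using Impl = cute::SM90::GMMA::MMA_64x8x16_F16F16F16_SS<
      detail::MajorValue<tnspA>::value, detail::MajorValue<tnspB>::value,
      detail::ScaleInValue<scaleA>::value, detail::ScaleInValue<scaleB>::value>;
  TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
                                bool scale_out) {
    detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out);
  }
};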
// Function templates that delegate to the class-template implementations.
template <DataType A_type, DataType B_type, DataType C_type, int M, int N, template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1> int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
...@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, ...@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
scaleB>::execute(desc_a, desc_b, c, scale_out); scaleB>::execute(desc_a, desc_b, c, scale_out);
} }
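As a usage sketch (hypothetical call site; the shared-memory matrix descriptors and accumulator layout are produced by the surrounding codegen, not shown here), a caller accumulating a 64x64 f16 tile from two shared-memory operands would look roughly like:

// Hypothetical call site for wgmma_ss: desc_a/desc_b are SMEM matrix
// descriptors built elsewhere; acc holds the 16 per-thread accumulator
// registers that an M64N64K16 f16 WGMMA owns.
TL_DEVICE void example_wgmma_ss_call(uint64_t desc_a, uint64_t desc_b,
                                     uint32_t (&acc)[16]) {
  // K-major A and B, +1 scaling, accumulate into the existing values of acc.
  wgmma_ss<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, 64, 64,
           16, /*tnspA=*/false, /*tnspB=*/false>(desc_a, desc_b, acc,
                                                 /*scale_out=*/true);
}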
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
WgmmaRSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
scaleB>::execute(a, desc_b, c, scale_out);
}

// ================================= Mixed Precision Support
// =================================

// Mixed precision: S8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x S8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E4M3 x E5M2 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e5m2,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E5M2 x E4M3 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e5m2, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// ================================= Convenience Templates
// =================================
// Type trait to determine the number of output registers needed
template <DataType C_type, int M, int N> struct WgmmaOutputRegs {
static constexpr int value =
(M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8);
};
// Type trait to get element size in bits
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};
} // namespace tl } // namespace tl
\ No newline at end of file
...@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() { ...@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() {
cute::warpgroup_wait<NumMma>(); cute::warpgroup_wait<NumMma>();
} }
TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
TL_DEVICE void warpgroup_fence_operand(float *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
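These helpers fence every accumulator register around an asynchronous WGMMA sequence so that the compiler does not reorder reads or writes of the accumulator across the async MMA. A minimal sketch of the intended call pattern follows; it assumes the cute::warpgroup_arrive / cute::warpgroup_commit_batch helpers from CUTLASS and the wgmma_ss wrapper shown earlier are visible here, and the descriptor arguments are placeholders.

// Minimal sketch (assumptions noted above) of fencing an async WGMMA group.
TL_DEVICE void example_fenced_mma(uint64_t desc_a, uint64_t desc_b,
                                  uint32_t (&acc)[16]) {
  warpgroup_fence_operand(acc, 16); // keep acc live before the async MMA
  cute::warpgroup_arrive();
  wgmma_ss<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, 64, 64,
           16, false, false>(desc_a, desc_b, acc, true);
  cute::warpgroup_commit_batch();
  warpgroup_wait<0>();              // wait for the committed WGMMA group
  warpgroup_fence_operand(acc, 16); // fence again before acc is consumed
}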
// Template parameter: // Template parameter:
// thread_extent: the logical size (in number of threads) of each "group" // thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative // within which we want to elect exactly ONE representative
......
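A minimal sketch of the election described by the comment above: with thread_extent = 32 (one group per warp), exactly one lane per group takes the branch. The helper name below is illustrative only; the real implementation may additionally use cute::elect_one_sync for warp-level election.

// Hypothetical sketch of electing one representative thread per group.
template <int thread_extent>
TL_DEVICE bool example_elect_one_per_group() {
  return (threadIdx.x % thread_extent) == 0;
}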
#pragma once #pragma once
#include "common.h" #include "common.h"
#include <cstdint>
#include <type_traits>
namespace tl { namespace tl {
// Select a wider accumulator type for improved numerical accuracy.
// Default: accumulate in the same type. Specialize FP16/BF16 to float.
template <typename T> struct AccType {
using type = T;
};
template <> struct AccType<half_t> {
using type = float;
};
template <> struct AccType<bfloat16_t> {
using type = float;
};
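A small illustrative use of AccType (not part of the source): accumulate half-precision inputs in float and narrow back to the storage type once at the end.

// Illustrative: accumulate half_t / bfloat16_t inputs in float via AccType.
template <typename T>
TL_DEVICE T example_accumulate(const T *src, int n) {
  using Acc = typename AccType<T>::type; // float for half_t / bfloat16_t
  Acc acc = Acc(0);
  for (int i = 0; i < n; ++i)
    acc = acc + Acc(src[i]);
  return T(acc); // convert back to the storage type at the end
}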
struct SumOp { struct SumOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) { template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y; return x + y;
...@@ -40,53 +54,6 @@ struct BitXorOp { ...@@ -40,53 +54,6 @@ struct BitXorOp {
} }
}; };
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
template <typename T>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int total_dest, int reduce_extent, int tail,
T init_value) {
if (total_dest <= 0 || reduce_extent <= 0)
return;
constexpr int kWarpSize = 32;
static_assert(Threads % kWarpSize == 0,
"SharedReduceWarp expects blockDim.x to be a multiple of "
"warp size on CUDA.");
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane = tid % kWarpSize;
const int num_warps = Threads / kWarpSize;
for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
const int suffix = tail == 1 ? 0 : dest_idx % tail;
const int src_base = (prefix * reduce_extent) * tail + suffix;
const int dst_index = prefix * tail + suffix;
T partial = init_value;
for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
T val = src[src_base + rv * tail];
if constexpr (UseAbs) {
val = val < T(0) ? -val : val;
}
partial = Reducer()(partial, val);
}
unsigned mask = __activemask();
for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
T other = __shfl_down_sync(mask, partial, offset);
partial = Reducer()(partial, other);
}
if (lane == 0) {
if constexpr (NeedAccumulate) {
partial = Reducer()(dst[dst_index], partial);
}
dst[dst_index] = partial;
}
}
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0, template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads> int all_threads = threads>
struct AllReduce { struct AllReduce {
...@@ -102,7 +69,7 @@ struct AllReduce { ...@@ -102,7 +69,7 @@ struct AllReduce {
__syncthreads(); __syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else { } else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
} }
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
...@@ -122,7 +89,7 @@ struct AllReduce { ...@@ -122,7 +89,7 @@ struct AllReduce {
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else { } else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset))); x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
} }
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
...@@ -159,7 +126,7 @@ template <int threads, bool reverse = false> struct CumSum1D { ...@@ -159,7 +126,7 @@ template <int threads, bool reverse = false> struct CumSum1D {
#pragma unroll #pragma unroll
for (int off = 1; off < SEG; off <<= 1) { for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off); T n = (T)tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off) if (lane < SEG - off)
val += n; val += n;
} }
...@@ -234,7 +201,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -234,7 +201,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll #pragma unroll
for (int off = 1; off < SEG; off <<= 1) { for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off); T n = tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off) if (lane < SEG - off)
val += n; val += n;
} }
...@@ -244,10 +211,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -244,10 +211,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W) if (real_col < W)
dst[real_row * W + real_col] = val; dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, (T)0); T segSum = tl::shfl_sync(MASK, val, 0);
if (lane == 0) if (lane == 0)
carry = segSum; carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0); carry = tl::shfl_sync(MASK, carry, 0);
} }
} else { } else {
for (int seg = 0; seg * SEG < W; ++seg) { for (int seg = 0; seg * SEG < W; ++seg) {
...@@ -260,7 +227,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -260,7 +227,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll #pragma unroll
for (int off = 1; off < SEG; off <<= 1) { for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off); T n = tl::shfl_up_sync(MASK, val, off);
if (lane >= off) if (lane >= off)
val += n; val += n;
} }
...@@ -270,10 +237,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D { ...@@ -270,10 +237,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W) if (real_col < W)
dst[real_row * W + real_col] = val; dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, SEG - 1); T segSum = tl::shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1) if (lane == SEG - 1)
carry = segSum; carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1); carry = tl::shfl_sync(MASK, carry, SEG - 1);
} }
} }
} }
......
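The cumulative-sum kernels above follow the same segment-scan pattern: within each SEG-lane segment the shuffle loop builds an inclusive prefix sum in log2(SEG) steps, then the segment total is broadcast as a carry into the next segment. A standalone sketch of the forward inclusive step (assuming SEG = 32 and a full-warp mask) looks like this:

// Minimal sketch of the forward inclusive warp scan used above.
TL_DEVICE float example_inclusive_scan(float val) {
  const unsigned MASK = 0xffffffffu;
  const int lane = threadIdx.x % 32;
#pragma unroll
  for (int off = 1; off < 32; off <<= 1) {
    float n = __shfl_up_sync(MASK, val, off); // value from `off` lanes below
    if (lane >= off)
      val += n;
  }
  return val; // lane i now holds the sum of lanes 0..i
}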
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#endif #endif
#include "common.h" #include "common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl { namespace tl {
...@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a, ...@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
} }
inline __device__ void amma_commit(uint64_t const *smem_ptr) { // Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier
TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) {
uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr); uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::" if (cute::elect_one_sync()) {
"cluster.b64 [%0];" asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
: "cluster.b64 [%0];"
: "r"(bar_intptr)); :
: "r"(bar_intptr));
}
} }
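Usage sketch (illustrative only): after the tcgen05 MMA instructions have been issued, the producer calls this wrapper so that one elected lane signals completion on a shared-memory mbarrier; the consumer later waits on the same barrier before reading the accumulator. The wait helper named in the comment is an assumption about the surrounding runtime.

// Illustrative producer-side call; `bar_smem` is assumed to be an mbarrier
// object placed in shared memory by the caller.
TL_DEVICE void example_signal_mma_done(uint64_t *bar_smem) {
  tcgen05_mma_arrive(bar_smem);
  // A consumer thread would subsequently wait on the same mbarrier (e.g. via
  // the runtime's mbarrier wait helper) before reading the accumulator.
}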
} // namespace tl } // namespace tl
\ No newline at end of file
...@@ -47,7 +47,7 @@ public: ...@@ -47,7 +47,7 @@ public:
} }
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op); Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) { alloc_buffers.MutateByApply([this](Buffer buf) {
auto storage_scope = auto storage_scope =
...@@ -58,7 +58,7 @@ public: ...@@ -58,7 +58,7 @@ public:
buf->dtype.bytes()); buf->dtype.bytes());
if (!new_shape.same_as(buf->shape)) { if (!new_shape.same_as(buf->shape)) {
ObjectPtr<BufferNode> new_buffer = ObjectPtr<BufferNode> new_buffer =
make_object<BufferNode>(*(buf.get())); tvm::ffi::make_object<BufferNode>(*(buf.get()));
new_buffer->shape = std::move(new_shape); new_buffer->shape = std::move(new_shape);
buffer_remap_.Set(buf, Buffer(new_buffer)); buffer_remap_.Set(buf, Buffer(new_buffer));
return Buffer(new_buffer); return Buffer(new_buffer);
...@@ -73,7 +73,7 @@ public: ...@@ -73,7 +73,7 @@ public:
} }
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store_node = GetRef<BufferStore>(op); auto store_node = tvm::ffi::GetRef<BufferStore>(op);
Buffer buf = op->buffer; Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) { if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf]; buf = buffer_remap_[buf];
...@@ -83,7 +83,7 @@ public: ...@@ -83,7 +83,7 @@ public:
} }
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load_node = GetRef<BufferLoad>(op); auto load_node = tvm::ffi::GetRef<BufferLoad>(op);
Buffer buf = op->buffer; Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) { if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf]; buf = buffer_remap_[buf];
...@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { ...@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations", {}); "tl.AlignDynamicSharedMemoryAllocations", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations", refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
AlignDynamicSharedMemoryAllocations); AlignDynamicSharedMemoryAllocations);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -46,13 +46,13 @@ public: ...@@ -46,13 +46,13 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) { if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is. // If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else if (op->attr_key == tir::attr::thread_extent || } else if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::pipeline_exec_scope || op->attr_key == tir::attr::pipeline_exec_scope ||
op->attr_key == tir::attr::device_scope) { op->attr_key == tir::attr::device_scope) {
// These attributes are only allowed in device-side code, so // These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target. // they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op); Stmt body = tvm::ffi::GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else { } else {
// All other annotations are ignored // All other annotations are ignored
...@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() { ...@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions", refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions); AnnotateDeviceRegions);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -124,7 +124,9 @@ private: ...@@ -124,7 +124,9 @@ private:
} }
auto producer_body = if_then_else->then_case; auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case; Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined"; // In some degenerate warp-specialized patterns (e.g., producer-only),
// the consumer body may be absent. Handle gracefully by only annotating
// the producer side when consumer is missing.
auto dec_reg = nreg_[0].as<IntImmNode>()->value; auto dec_reg = nreg_[0].as<IntImmNode>()->value;
auto inc_reg = nreg_[1].as<IntImmNode>()->value; auto inc_reg = nreg_[1].as<IntImmNode>()->value;
...@@ -150,15 +152,20 @@ private: ...@@ -150,15 +152,20 @@ private:
producer_stmts.push_back(producer_body); producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts); auto new_producer_body = SeqStmt(producer_stmts);
Array<Stmt> consumer_stmts; Stmt new_if_stmt;
consumer_stmts.push_back(inc_reg_stmt); if (consumer_body.defined()) {
consumer_stmts.push_back(consumer_body.value()); Array<Stmt> consumer_stmts;
auto new_consumer_body = SeqStmt(consumer_stmts); consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
} else {
// No consumer branch; keep the if-then form.
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
}
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt); auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
return new_attr; return new_attr;
} else { } else {
return StmtExprMutator::VisitStmt_(op); return StmtExprMutator::VisitStmt_(op);
...@@ -181,11 +188,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() { ...@@ -181,11 +188,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc", refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc); AnnotateWarpGroupRegAlloc);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value, ...@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
Bind_(arg, value, arg_name, with_let); Bind_(arg, value, arg_name, with_let);
} }
void ArgBinder::BindArray(const Array<PrimExpr> &arg, void ArgBinder::BindArray(const ffi::Array<PrimExpr> &arg,
const Array<PrimExpr> &value, const ffi::Array<PrimExpr> &value,
const std::string &arg_name) { const std::string &arg_name) {
ICHECK_EQ(arg.size(), value.size()) ICHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch"; << "Argument " << arg_name << " array size mismatch";
...@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Assert the buffer is compact // Assert the buffer is compact
DataType stype = buffer->DefaultIndexType(); DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1); PrimExpr expect_stride = make_const(stype, 1);
Array<PrimExpr> conds; ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) { for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1; size_t k = i - 1;
PrimExpr svalue = PrimExpr svalue =
......
...@@ -82,7 +82,8 @@ public: ...@@ -82,7 +82,8 @@ public:
* \param value The target expression value * \param value The target expression value
* \param arg_name argument name. * \param arg_name argument name.
*/ */
void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value, void BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name); const std::string &arg_name);
/*! /*!
* \brief Bind symbolic buffer to another symbolic buffer * \brief Bind symbolic buffer to another symbolic buffer
...@@ -149,7 +150,7 @@ public: ...@@ -149,7 +150,7 @@ public:
*/ */
const std::vector<Stmt> &init_nest() const { return init_nest_; } const std::vector<Stmt> &init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */ /*! \return Handle data type of the data */
const Map<Var, PrimExpr> &def_handle_dtype() const { const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
return def_handle_dtype_; return def_handle_dtype_;
} }
...@@ -164,7 +165,7 @@ private: ...@@ -164,7 +165,7 @@ private:
/*! \brief Initialize nest */ /*! \brief Initialize nest */
std::vector<Stmt> init_nest_; std::vector<Stmt> init_nest_;
/*! \brief handle data type in the definitions */ /*! \brief handle data type in the definitions */
Map<Var, PrimExpr> def_handle_dtype_; ffi::Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */ /*! \brief asserts generated */
std::vector<Stmt> asserts_; std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */ /*! \brief internal analyzer. */
......