Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
@@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) {
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
@@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}
runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
@@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
std::string());
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_hip", BuildTileLangHIP)
.def("target.build.tilelang_hip_without_compile",
BuildTileLangHIPWithoutCompile);
});
}
} // namespace codegen
} // namespace tvm
@@ -5,6 +5,9 @@
#include "utils.h"
#include "../support/ffi_aliases.h"
#include <tvm/node/node.h>
namespace tvm {
namespace tl {
@@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) {
}
int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
const std::string arch_str = s.value();
ICHECK(arch_str.size() >= 3);
ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
@@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // if mcpu starts with "gfx9", it is CDNA
return mcpu.find("gfx9") == 0;
}
@@ -94,7 +97,7 @@ bool TargetHasAsyncCopy(Target target) {
return arch >= 80;
} else if (TargetIsCDNA(target)) {
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
if (mcpu.rfind("gfx9", 0) == 0) {
int gfx_version = std::stoi(mcpu.substr(3, 2));
return gfx_version >= 94;
@@ -141,7 +144,7 @@ int TargetGetWarpSize(Target target) {
return res;
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tl.TargetIsCuda",
@@ -170,7 +173,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Target target) { return TargetHasBulkCopy(target); })
.def("tl.TargetGetWarpSize",
[](Target target) { return TargetGetWarpSize(target); });
});
}
} // namespace tl
} // namespace tvm
@@ -10,6 +10,9 @@
#include <cutlass/numeric_types.h>
#include <math_constants.h>
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>
using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
@@ -285,6 +288,138 @@ union GmmaDescriptor {
}
};
union Tcgen05SMemDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
// Bitfield implementation avoids the need for shifts in assignment
struct {
// start_address, bit [0,14), 4LSB not included
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// leading dimension byte offset, bit [16,30), 4LSB not included
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// stride dimension byte offset, bit [32,46), 4LSB not included
uint16_t stride_byte_offset_ : 14,
version_ : 2; // 14 bits [0,14), 2 bits [14,16)
// base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1,
: 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
// layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0,
// SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4,
// SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5,
// N/A = 7
    uint8_t : 5, layout_type_ : 3; // 5 bits unused, 3 bits [5,8)
} bitfield;
// Separate the field, as we may only update one part of desc
struct {
uint32_t lo;
uint32_t hi;
} words;
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor
operator+(const T &offset) const {
Tcgen05SMemDescriptor ret;
// Address addition is in units of 16 bytes (4 LSB not encoded)
ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
};
//
// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout)
//
union Tcgen05InstrDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint32_t desc_;
uint16_t reg16_[2];
// Bitfield implementation mirrors cute::UMMA::InstrDescriptor
struct {
// bit [ 0, 2) : Sparse meta data id2
uint16_t sparse_id2_ : 2,
// bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for
// F32F16/S8/MXF8F6F4
sparse_flag_ : 1,
// bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8
saturate_ : 1,
// bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
c_format_ : 2,
// padding
: 1,
// bit [ 7,10) : see UMMA format encoding
a_format_ : 3,
// bit [10,13) : see UMMA format encoding
b_format_ : 3,
// bit [13,14) : 0 = no negate. 1 = negate
a_negate_ : 1,
// bit [14,15) : 0 = no negate. 1 = negate
b_negate_ : 1,
// bit [15,16) : 0 = K-major. 1 = MN-major
a_major_ : 1;
// Upper 16 bits
uint16_t b_major_ : 1, // bit [16,17)
n_dim_ : 6, // bit [17,23) : 3 LSBs not included
: 1, // padding
m_dim_ : 5, // bit [24,29) : 4 LSBs not included
: 1, // padding
max_shift_ : 2; // bit [30,32)
} bitfield;
// Decay to a uint32_t
CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept {
return desc_;
}
};
// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
for (int i = 0; i < size; i++) {
@@ -323,7 +458,7 @@ TL_DEVICE void __sync_thread_partial() {
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
@@ -333,15 +468,151 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}
template <typename T>
TL_DEVICE void
initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor,
T *start_address, int leading_byte_offset,
int stride_byte_offset, int base_offset,
bool leading_is_absolute, int swizzle_mode) {
descriptor.bitfield.start_address_ =
      static_cast<uint16_t>(cute::cast_smem_ptr_to_uint(start_address) >> 4);
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
descriptor.bitfield.version_ = 1;
descriptor.bitfield.base_offset_ = base_offset & 0x7;
descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0;
descriptor.bitfield.layout_type_ = swizzle_mode & 0x7;
}
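// Hypothetical usage sketch (the function name and offset values below are
// illustrative only): fill a Tcgen05SMemDescriptor for a 128B-swizzled tile
// via the helper above. Offsets are encoded in 16-byte units (the 4 LSBs are
// dropped), and swizzle_mode 2 corresponds to the SWIZZLE_128B encoding noted
// in the layout_type_ comment.
template <typename T>
TL_DEVICE uint64_t make_smem_desc_example(T *smem_tile) {
  Tcgen05SMemDescriptor desc;
  initialize_tcgen05_descriptor(desc, smem_tile,
                                /*leading_byte_offset=*/1,
                                /*stride_byte_offset=*/8,
                                /*base_offset=*/0,
                                /*leading_is_absolute=*/false,
                                /*swizzle_mode=*/2);
  return static_cast<uint64_t>(desc);
}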
template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
T offset) {
descriptor.reg32_[0] += (offset >> 4);
}
// Wrap the cute FP8 types below to add the desired conversion from bfloat16_t.
struct float_e4m3_t : public cute::float_e4m3_t {
using cute::float_e4m3_t::float_e4m3_t;
CUTLASS_HOST_DEVICE
float_e4m3_t() = default;
CUTLASS_HOST_DEVICE
explicit float_e4m3_t(__nv_bfloat16 x)
: float_e4m3_t(static_cast<float>(x)) {}
};
struct float_e5m2_t : public cute::float_e5m2_t {
using cute::float_e5m2_t::float_e5m2_t;
CUTLASS_HOST_DEVICE
float_e5m2_t() = default;
CUTLASS_HOST_DEVICE
explicit float_e5m2_t(__nv_bfloat16 x)
: float_e5m2_t(static_cast<float>(x)) {}
};
template <typename T> struct to_cute_type {
using type = T;
};
template <> struct to_cute_type<tl::float_e4m3_t> {
using type = cute::float_e4m3_t;
};
template <> struct to_cute_type<tl::float_e5m2_t> {
using type = cute::float_e5m2_t;
};
} // namespace tl
namespace cutlass {
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) { return ::hexp(x); }
} // namespace cutlass
//
// Type-safe warp shuffle helpers for 16-bit float types
// These wrappers avoid relying on implicit conversions that may be disallowed
// (e.g., converting float -> cutlass::bfloat16_t) by explicitly promoting to
// float for the shuffle and then down-converting.
//
namespace tl {
// Generic passthroughs
template <typename T>
TL_DEVICE T shfl_xor_sync(unsigned mask, T val, int laneMask) {
return __shfl_xor_sync(mask, val, laneMask);
}
template <typename T>
TL_DEVICE T shfl_down_sync(unsigned mask, T val, int delta) {
return __shfl_down_sync(mask, val, delta);
}
template <typename T>
TL_DEVICE T shfl_up_sync(unsigned mask, T val, int delta) {
return __shfl_up_sync(mask, val, delta);
}
template <typename T> TL_DEVICE T shfl_sync(unsigned mask, T val, int srcLane) {
return __shfl_sync(mask, val, srcLane);
}
// Specializations for cutlass::half_t
template <>
TL_DEVICE half_t shfl_xor_sync(unsigned mask, half_t val, int laneMask) {
float f = static_cast<float>(val);
float r = __shfl_xor_sync(mask, f, laneMask);
return half_t(r);
}
template <>
TL_DEVICE half_t shfl_down_sync(unsigned mask, half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down_sync(mask, f, delta);
return half_t(r);
}
template <>
TL_DEVICE half_t shfl_up_sync(unsigned mask, half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up_sync(mask, f, delta);
return half_t(r);
}
template <> TL_DEVICE half_t shfl_sync(unsigned mask, half_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl_sync(mask, f, srcLane);
return half_t(r);
}
// Specializations for cutlass::bfloat16_t
template <>
TL_DEVICE bfloat16_t shfl_xor_sync(unsigned mask, bfloat16_t val,
int laneMask) {
float f = static_cast<float>(val);
float r = __shfl_xor_sync(mask, f, laneMask);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_down_sync(unsigned mask, bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down_sync(mask, f, delta);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_up_sync(unsigned mask, bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up_sync(mask, f, delta);
return bfloat16_t(r);
}
template <>
TL_DEVICE bfloat16_t shfl_sync(unsigned mask, bfloat16_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl_sync(mask, f, srcLane);
return bfloat16_t(r);
}
} // namespace tl
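// Hypothetical usage sketch (the function name is illustrative only): a
// butterfly warp reduction over cutlass::half_t that relies on the typed
// tl::shfl_xor_sync overload above, which promotes to float for the shuffle
// and converts back.
TL_DEVICE half_t warp_sum_example(half_t v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v = half_t(static_cast<float>(v) +
               static_cast<float>(tl::shfl_xor_sync(0xffffffffu, v, offset)));
  return v; // every lane ends up holding the full-warp sum
}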
#pragma once
#include "common.h"
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>
using fp8_e4_t = cute::float_e4m3_t;
using fp8_e5_t = cute::float_e5m2_t;
using fp8_e4_t = tl::float_e4m3_t;
using fp8_e5_t = tl::float_e5m2_t;
struct __CUDA_ALIGN__(2) fp8_e4_2_t {
fp8_e4_t x;
......
@@ -263,16 +263,18 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>::type;
using C_type = C_type_raw;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
using Instruction = DispatchInstruction<A_type_raw, B_type_raw, C_type_raw,
num_warp_m, num_warp_n, N>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
!trans_A, num_warp_m, lda>;
......
@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>::type;
typename std::conditional<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>::type;
using C_type = C_type_raw;
static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
......
@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>;
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
                               tfloat32_t, B_type_cute>;
using C_type = C_type_raw;
static constexpr GMMA::Major GmmaMajorA =
......
@@ -13,10 +13,12 @@ class GemmTensorOp {
public:
static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>;
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
......
#pragma once
#include "../common.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
template <class Impl> struct MmaImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c) {
call_fma_impl<Impl>(d, a, b, c,
std::make_index_sequence<MmaImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kCRegs>{});
}
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate>
struct MmaDispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync: unsupported configuration");
}
};
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
NValue, KValue, TransAValue, TransBValue, \
SaturateValue, ImplType) \
template <> \
struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, MValue, NValue, KValue, \
TransAValue, TransBValue, SaturateValue> { \
using Impl = ImplType; \
using Traits = MmaImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma<Impl>(d, a, b, c); \
} \
};
// FP16 inputs (TN layout: A row-major, B column-major)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F16F16F16F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32F16F16F32_TN)
// BF16 inputs
TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32BF16BF16F32_TN)
// INT8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S8S8S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U8U8S32_TN)
// INT4 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S4S4S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U4U4S32_TN)
// FP8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)
// TF32 inputs (FP32 math on Tensor Cores)
// Support both k=4 and k=8 variants on SM80
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4,
false, true, false,
cute::SM80_16x8x4_F32TF32TF32F32_TN)
TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8,
false, true, false,
cute::SM80_16x8x8_F32TF32TF32F32_TN)
// FP64 inputs (DMMA: m8n8k4, TN layout)
TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true,
false, cute::SM80_8x8x4_F64F64F64F64_TN)
#undef TL_DEFINE_MMA_DISPATCHER
} // namespace detail
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate = false>
TL_DEVICE void mma_sync(
typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA, TransB,
Saturate>::CRegType *c,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::ARegType *a,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::BRegType *b) {
using Dispatcher = detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync: unsupported configuration");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
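// Hypothetical usage sketch (the function name is illustrative only): issue
// one m16n8k16 F16 x F16 -> F32 tile per warp. Fragment sizes follow the
// underlying cute::SM80_16x8x16_F32F16F16F32_TN operation selected by the
// dispatcher: A is 4 packed 32-bit registers, B is 2, and the accumulator is
// 4 floats per thread.
namespace tl {
TL_DEVICE void mma_tile_f16f32_example(const unsigned (&a_frag)[4],
                                       const unsigned (&b_frag)[2],
                                       float (&acc)[4]) {
  mma_sync<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32, 16, 8,
           16, /*TransA=*/false, /*TransB=*/true>(acc, a_frag, b_frag);
}
} // namespace tl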
#pragma once
#include "../common.h"
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
// SM70 MMA Instruction Traits and Implementations
// SM70 supports m16n16k4 (m8n8k4 instruction at warp level) with FP16/FP32
// accumulation
// Base template for SM70 MMA implementation
template <DataType AType, DataType BType, DataType CType, bool TransA,
bool TransB>
struct MmaSm70Impl {
// Default: unsupported configuration
static constexpr bool kSupported = false;
static TL_DEVICE void exec(void *, const void *, const void *, const void *) {
static_assert(always_false_v<std::integral_constant<bool, TransA>>,
"tl::mma_sync_sm70: unsupported configuration");
}
};
// FP16 inputs, FP16 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
true, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, true> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP16 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
false, false> {
using DRegisters = unsigned[4];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = unsigned[4];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(unsigned &d0, unsigned &d1, unsigned &d2,
unsigned &d3, unsigned a0, unsigned a1, unsigned b0,
unsigned b1, unsigned c0, unsigned c1, unsigned c2,
unsigned c3) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
"{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "r"(c0), "r"(c1),
"r"(c2), "r"(c3));
}
};
// FP16 inputs, FP32 accumulation - col.col (TransA=true, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - col.row (TransA=true, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
true, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.col (TransA=false, TransB=true)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, true> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// FP16 inputs, FP32 accumulation - row.row (TransA=false, TransB=false)
template <>
struct MmaSm70Impl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
false, false> {
using DRegisters = float[8];
using ARegisters = unsigned[2];
using BRegisters = unsigned[2];
using CRegisters = float[8];
static constexpr bool kSupported = true;
static TL_DEVICE void fma(float &d0, float &d1, float &d2, float &d3,
float &d4, float &d5, float &d6, float &d7,
unsigned a0, unsigned a1, unsigned b0, unsigned b1,
float c0, float c1, float c2, float c3, float c4,
float c5, float c6, float c7) {
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 "
"{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), "=f"(d4), "=f"(d5),
"=f"(d6), "=f"(d7)
: "r"(a0), "r"(a1), "r"(b0), "r"(b1), "f"(c0), "f"(c1),
"f"(c2), "f"(c3), "f"(c4), "f"(c5), "f"(c6), "f"(c7));
}
};
// Helper to extract register types
template <class Impl> struct MmaSm70ImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
// Dispatcher for SM70 MMA operations
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
struct MmaSm70Dispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs and FP16/FP32 "
"accumulation.");
}
};
// Helper to call fma with unpacked register arrays
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma_sm70(typename MmaSm70ImplTraits<Impl>::DReg *d,
const typename MmaSm70ImplTraits<Impl>::AReg *a,
const typename MmaSm70ImplTraits<Impl>::BReg *b,
const typename MmaSm70ImplTraits<Impl>::CReg *c) {
call_fma_impl_sm70<Impl>(
d, a, b, c, std::make_index_sequence<MmaSm70ImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaSm70ImplTraits<Impl>::kCRegs>{});
}
// Define dispatchers for all supported SM70 configurations
// Note: m8n8k4 instruction computes m16n16k4 at warp level
#define TL_DEFINE_MMA_SM70_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, \
TransAValue, TransBValue) \
template <> \
struct MmaSm70Dispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, 16, 16, 4, TransAValue, \
TransBValue> { \
using Impl = MmaSm70Impl<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, TransAValue, TransBValue>; \
using Traits = MmaSm70ImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync_sm70 requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma_sm70<Impl>(d, a, b, c); \
} \
};
// FP16 inputs with FP16 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat16, false, false)
// FP16 inputs with FP32 accumulation (all layout combinations)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, true, false)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, true)
TL_DEFINE_MMA_SM70_DISPATCHER(kFloat16, kFloat16, kFloat32, false, false)
#undef TL_DEFINE_MMA_SM70_DISPATCHER
} // namespace detail
/// SM70 MMA synchronous instruction wrapper
/// Supports m16n16k4 shape (m8n8k4 instruction at warp level) with FP16 inputs
/// and FP16/FP32 accumulation
///
/// @tparam AType Input A data type (kFloat16)
/// @tparam BType Input B data type (kFloat16)
/// @tparam CType Accumulator/output data type (kFloat16 or kFloat32)
/// @tparam M Matrix M dimension (16)
/// @tparam N Matrix N dimension (16)
/// @tparam K Matrix K dimension (4)
/// @tparam TransA Whether A is transposed (false=row-major, true=col-major)
/// @tparam TransB Whether B is transposed (false=row-major, true=col-major)
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB>
TL_DEVICE void mma_sync_sm70(
typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA,
TransB>::CRegType *c,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::ARegType *a,
const typename detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K,
TransA, TransB>::BRegType *b) {
using Dispatcher =
detail::MmaSm70Dispatcher<AType, BType, CType, M, N, K, TransA, TransB>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync_sm70: unsupported configuration. "
"SM70 only supports m16n16k4 with FP16 inputs.");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
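// Hypothetical usage sketch (the function name is illustrative only): one
// m16n16k4 FP16 x FP16 -> FP32 tile in row.col layout. Register counts follow
// the m8n8k4 PTX operands above: A and B are 2 packed 32-bit registers each,
// and the FP32 accumulator is 8 floats per thread.
namespace tl {
TL_DEVICE void mma_tile_sm70_example(const unsigned (&a_frag)[2],
                                     const unsigned (&b_frag)[2],
                                     float (&acc)[8]) {
  mma_sync_sm70<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32, 16,
                16, 4, /*TransA=*/false, /*TransB=*/true>(acc, a_frag, b_frag);
}
} // namespace tl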
#pragma once
#include "../common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ss: unsupported accumulator type");
}
// TS variants: A from TMEM, B from SMEM (desc)
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ts(uint32_t const & /*tmem_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ts: unsupported accumulator type");
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kBFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat16>(tmem_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kTensorFloat32>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kInt8>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e4m3>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, "
"{%5, %6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e5m2>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat8_e4m3>(tmem_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
  // idescE upper 32 bits carry the instruction descriptor; lower 32 ignored
  // for SS. Loading the TMEM base from the shared memory slot is handled by
  // the caller.
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx
// Generic declaration falls back to static assert
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ws_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ws_ss: unsupported accumulator type");
}
// F16/BF16 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// INT8 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// FP8 ws (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3);
}
} // namespace tl
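// Hypothetical usage sketch (the function name is illustrative only): issue a
// single f16-kind SS MMA with accumulation enabled and all four disable masks
// zero. desc_a/desc_b are SMEM matrix descriptors and idesc carries the
// instruction-descriptor bits, both built by the caller.
namespace tl {
TL_DEVICE void issue_f16_mma_example(uint64_t desc_a, uint64_t desc_b,
                                     uint32_t tmem_c, uint32_t idesc) {
  tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, /*scalec=*/1u,
                                    idesc, 0, 0, 0, 0);
}
} // namespace tl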
#pragma once
#include "../common.h"
#include "cute/arch/mma_sm90_gmma.hpp"
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Primary class template - default arguments removed, since specializations cannot have default arguments
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, "
"C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, "
"scaleB=%d\n",
(int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N,
K, (int)tnspA, (int)tnspB, scaleA, scaleB);
    // Temporarily comment out the static_assert to see the debug output
// static_assert(always_false_v<decltype(c)>,
// "wgmma_ss: No specialization available for given template
// parameters!");
};
};
// ================================= F16 x F16 -> F16
// =================================
namespace detail {
// M64N8K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
template <bool IsMnMajor> struct MajorValue {
static constexpr auto value =
IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K;
};
// M64N16K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
template <int Scale> struct ScaleInValue {
static_assert(Scale == 1 || Scale == -1,
"tl::wgmma requires scale factors of +1 or -1.");
static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One
: cute::SM90::GMMA::ScaleIn::Neg;
};
// M64N32K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
template <int Scale>
inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1);
// M64N64K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 64, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15},"
" %16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
template <class Impl> struct CallWgmmaSS {
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
static_assert(sizeof(CReg) == sizeof(uint32_t),
"tl::wgmma_ss expects 32-bit accumulator registers.");
// M64N96K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 96, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23}, "
"%24, %25, p, %27, %28, %29, %30;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
template <size_t... Idx>
TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c,
cute::SM90::GMMA::ScaleOut scale,
std::index_sequence<Idx...>) {
Impl::fma(desc_a, desc_b, c[Idx]..., scale);
}
};
// M64N128K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 128, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
: cute::SM90::GMMA::ScaleOut::Zero;
auto c = reinterpret_cast<CReg *>(c_raw);
Run(desc_a, desc_b, c, scale, std::make_index_sequence<kCRegs>{});
}
};
// M64N192K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 192, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47}, "
"%48, %49, p, %51, %52, %53, %54;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
template <class Impl> struct CallWgmmaRS {
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
static_assert(sizeof(AReg) == sizeof(uint32_t),
"tl::wgmma_rs expects 32-bit register operands for A.");
static_assert(sizeof(CReg) == sizeof(uint32_t) ||
sizeof(CReg) == sizeof(float),
"tl::wgmma_rs expects 32-bit accumulator registers.");
template <size_t... AIdx, size_t... CIdx>
TL_DEVICE static void
Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale,
std::index_sequence<AIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
}
};
// M64N256K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 256, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47, "
"%48, %49, %50, %51, %52, %53, %54, %55, "
"%56, %57, %58, %59, %60, %61, %62, %63}, "
"%64, %65, p, %67, %68, %69, %70;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47]), "+r"(c[48]), "+r"(c[49]),
"+r"(c[50]), "+r"(c[51]), "+r"(c[52]), "+r"(c[53]), "+r"(c[54]),
"+r"(c[55]), "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]),
"+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b,
uint32_t *c_raw, bool scale_out) {
auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
: cute::SM90::GMMA::ScaleOut::Zero;
auto a = reinterpret_cast<const AReg *>(a_raw);
auto c = reinterpret_cast<CReg *>(c_raw);
Run(a, desc_b, c, scale, std::make_index_sequence<kARegs>{},
std::make_index_sequence<kCRegs>{});
}
};
// ========================== F16 x F16 -> F32 ==========================
} // namespace detail
// M64N8K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
// M64N32K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15}, "
"%16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N64K16 F16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 64, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// ========================= BF16 x BF16 -> F32 =========================
// M64N8K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K16 BF16->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::wgmma_ss: unsupported configuration");
}
};
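// Example (sketch): any shape without a matching specialization falls into this
// primary template, so e.g. WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16,
// DataType::kFloat16, 64, 12, 16, false, false, 1, 1> fails to compile with the
// "unsupported configuration" diagnostic instead of emitting invalid PTX.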
// ========================= TF32 x TF32 -> F32 =========================
// M64N8K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 8, 8, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaRSImpl {
static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_rs: invalid scaleA");
static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_rs: invalid scaleB");
TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::wgmma_rs: unsupported configuration");
}
};
// M64N16K8 TF32->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32,
DataType::kFloat32, 64, 16, 8, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, tnspA, tnspB, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
detail::MajorValue<tnspB>::value, \
detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};
// ======================== INT8 x INT8 -> INT32 ========================
#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};
// M64N8K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_ss: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_ss: invalid scaleB"); \
static_assert(scaleA == 1 && scaleB == 1, \
"tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \
using Impl = cute::SM90::GMMA::ImplName; \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
} \
};
// M64N16K32 S8->S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 16,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, tnspA, tnspB, scaleA, scaleB> { \
static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
detail::MajorValue<tnspB>::value, \
detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};
// ======================== FP8 x FP8 -> F16/F32 ========================
#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
using Impl = \
cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
detail::ScaleInValue<scaleB>::value>; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};
// M64N8K32 E4M3->F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
ImplName) \
template <int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
K, false, false, scaleA, scaleB> { \
static_assert(detail::IsValidScale<scaleA>, \
"tl::wgmma_rs: invalid scaleA"); \
static_assert(detail::IsValidScale<scaleB>, \
"tl::wgmma_rs: invalid scaleB"); \
static_assert(scaleA == 1 && scaleB == 1, \
"tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \
using Impl = cute::SM90::GMMA::ImplName; \
TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
uint32_t *c, bool scale_out) { \
detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
} \
};
// M64N8K32 E4M3->F32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3,
DataType::kFloat32, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(40) \
OP(48) \
OP(56) \
OP(64) \
OP(72) \
OP(80) \
OP(88) \
OP(96) \
OP(104) \
OP(112) \
OP(120) \
OP(128) \
OP(136) \
OP(144) \
OP(152) \
OP(160) \
OP(168) \
OP(176) \
OP(184) \
OP(192) \
OP(200) \
OP(208) \
OP(216) \
OP(224) \
OP(232) \
OP(240) \
OP(248) \
OP(256)
#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(48) \
OP(64) \
OP(80) \
OP(96) \
OP(112) \
OP(128) \
OP(144) \
OP(160) \
OP(176) \
OP(192) \
OP(208) \
OP(224) \
OP(240) \
OP(256)
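// Example (sketch): each FOREACH list applies OP to every supported tile width,
// so TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS) expands to
// TL_WGMMA_DEFINE_F16_F16_F16_SS(8), ..._SS(16), ... up to ..._SS(256), i.e. one
// WgmmaSSImpl specialization per N, each delegating to the matching CUTLASS MMA.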
#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_SS)
#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_SS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_SS)
#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_SS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN);
#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_RS)
#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_RS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_RS)
#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_RS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN);
#undef TL_WGMMA_DEFINE_F16_F16_F16_SS
#undef TL_WGMMA_DEFINE_F16_F16_F32_SS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS
#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_F16_F16_RS
#undef TL_WGMMA_DEFINE_F16_F16_F32_RS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS
#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN
#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8
#undef TL_WGMMA_FOREACH_N_INT32_MUL8
#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_SS_GENERAL
#undef TL_WGMMA_DEFINE_SS_TN
#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_RS_GENERAL
#undef TL_WGMMA_DEFINE_RS_TN
// Function templates that delegate to the class-template specializations above.
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
......@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
scaleB>::execute(desc_a, desc_b, c, scale_out);
}
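// Example (sketch): one f16 x f16 -> f32 64x64x16 shared/shared WGMMA issue.
// Descriptor construction and the accumulator layout are assumed to be handled
// by the caller; each thread of the 128-thread warpgroup owns 32 f32 values
// here, viewed as raw 32-bit registers to match the uint32_t* interface above.
TL_DEVICE void wgmma_ss_f16f16f32_64x64_example(uint64_t desc_a, uint64_t desc_b,
                                                float (&acc)[32],
                                                bool accumulate) {
  uint32_t *c = reinterpret_cast<uint32_t *>(acc);
  wgmma_ss<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
           /*M=*/64, /*N=*/64, /*K=*/16,
           /*tnspA=*/false, /*tnspB=*/false>(desc_a, desc_b, c, accumulate);
}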
// ======================= Mixed Precision Support =======================
// Mixed precision: S8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x S8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E4M3 x E5M2 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e5m2,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E5M2 x E4M3 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e5m2, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// ======================= Convenience Templates =======================
// Type trait to determine the number of output registers needed
template <DataType C_type, int M, int N> struct WgmmaOutputRegs {
static constexpr int value =
(M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8);
};
// Type trait to get element size in bits
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};
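// Example (sketch): compile-time queries of the traits above; the exact
// register-count convention follows WgmmaOutputRegs as written here.
static_assert(ElementBits<DataType::kBFloat16>::value == 16,
              "bf16 elements are 16 bits wide");
static_assert(ElementBits<DataType::kFloat8_e4m3>::value == 8,
              "e4m3 elements are 8 bits wide");
static_assert(WgmmaOutputRegs<DataType::kFloat32, 64, 64>::value > 0,
              "f32 accumulators occupy at least one 32-bit register");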
WgmmaRSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
scaleB>::execute(a, desc_b, c, scale_out);
}
} // namespace tl
......@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() {
cute::warpgroup_wait<NumMma>();
}
TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
TL_DEVICE void warpgroup_fence_operand(float *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
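// Example (sketch): bracketing an accumulator fragment with operand fences so
// the compiler does not reorder register accesses across an asynchronous MMA.
// The issue/commit/wait sequence in between is a hypothetical stand-in.
TL_DEVICE void fence_accumulator_example(float (&acc)[32]) {
  warpgroup_fence_operand(acc, 32); // before handing the registers to wgmma
  // ... issue wgmma, commit the batch, and wait for completion here ...
  warpgroup_fence_operand(acc, 32); // after the results are guaranteed written
}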
// Template parameter:
// thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative
......
#pragma once
#include "common.h"
#include <cstdint>
#include <type_traits>
namespace tl {
// Select a wider accumulator type for improved numerical accuracy.
// Default: accumulate in the same type. Specialize FP16/BF16 to float.
template <typename T> struct AccType {
using type = T;
};
template <> struct AccType<half_t> {
using type = float;
};
template <> struct AccType<bfloat16_t> {
using type = float;
};
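// Example (sketch): accumulating a dot product in the wider type chosen by
// AccType and converting back to the element type at the end. For half_t and
// bfloat16_t the intermediate sums are carried in float.
template <typename T>
TL_DEVICE T dot_in_acc_type_example(const T *a, const T *b, int n) {
  using Acc = typename AccType<T>::type;
  Acc acc = Acc(0);
  for (int i = 0; i < n; ++i)
    acc += Acc(a[i]) * Acc(b[i]);
  return T(acc);
}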
struct SumOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x + y;
......@@ -40,53 +54,6 @@ struct BitXorOp {
}
};
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
template <typename T>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int total_dest, int reduce_extent, int tail,
T init_value) {
if (total_dest <= 0 || reduce_extent <= 0)
return;
constexpr int kWarpSize = 32;
static_assert(Threads % kWarpSize == 0,
"SharedReduceWarp expects blockDim.x to be a multiple of "
"warp size on CUDA.");
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane = tid % kWarpSize;
const int num_warps = Threads / kWarpSize;
for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
const int suffix = tail == 1 ? 0 : dest_idx % tail;
const int src_base = (prefix * reduce_extent) * tail + suffix;
const int dst_index = prefix * tail + suffix;
T partial = init_value;
for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
T val = src[src_base + rv * tail];
if constexpr (UseAbs) {
val = val < T(0) ? -val : val;
}
partial = Reducer()(partial, val);
}
unsigned mask = __activemask();
for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
T other = __shfl_down_sync(mask, partial, offset);
partial = Reducer()(partial, other);
}
if (lane == 0) {
if constexpr (NeedAccumulate) {
partial = Reducer()(dst[dst_index], partial);
}
dst[dst_index] = partial;
}
}
}
};
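// Example (sketch): reducing a [kRows, kCols] shared-memory tile along its last
// axis with one warp per output row. Here tail == 1 (the innermost axis is the
// one being reduced), total_dest == kRows and reduce_extent == kCols; a block
// size of 256 threads is an assumption for illustration.
template <int kRows, int kCols>
TL_DEVICE void row_sum_example(const float *__restrict__ src_smem,
                               float *__restrict__ dst_smem) {
  SharedReduceWarp<SumOp, /*Threads=*/256, /*UseAbs=*/false,
                   /*NeedAccumulate=*/false>::run(src_smem, dst_smem,
                                                  /*total_dest=*/kRows,
                                                  /*reduce_extent=*/kCols,
                                                  /*tail=*/1,
                                                  /*init_value=*/0.0f);
}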
template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads>
struct AllReduce {
......@@ -102,7 +69,7 @@ struct AllReduce {
__syncthreads();
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
......@@ -122,7 +89,7 @@ struct AllReduce {
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
x = Reducer()(x, tl::shfl_xor_sync(uint32_t(-1), x, offset));
}
if constexpr (offset == scale) {
return x;
......@@ -159,7 +126,7 @@ template <int threads, bool reverse = false> struct CumSum1D {
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
T n = (T)tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}
......@@ -234,7 +201,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_down_sync(MASK, val, off);
T n = tl::shfl_down_sync(MASK, val, off);
if (lane < SEG - off)
val += n;
}
......@@ -244,10 +211,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, (T)0);
T segSum = tl::shfl_sync(MASK, val, 0);
if (lane == 0)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, (T)0);
carry = tl::shfl_sync(MASK, carry, 0);
}
} else {
for (int seg = 0; seg * SEG < W; ++seg) {
......@@ -260,7 +227,7 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
#pragma unroll
for (int off = 1; off < SEG; off <<= 1) {
T n = (T)__shfl_up_sync(MASK, val, off);
T n = tl::shfl_up_sync(MASK, val, off);
if (lane >= off)
val += n;
}
......@@ -270,10 +237,10 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
if (real_col < W)
dst[real_row * W + real_col] = val;
T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
T segSum = tl::shfl_sync(MASK, val, SEG - 1);
if (lane == SEG - 1)
carry = segSum;
carry = (T)__shfl_sync(MASK, carry, SEG - 1);
carry = tl::shfl_sync(MASK, carry, SEG - 1);
}
}
}
......
......@@ -6,6 +6,7 @@
#endif
#include "common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
......@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
}
inline __device__ void amma_commit(uint64_t const *smem_ptr) {
// Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier
TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) {
uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
if (cute::elect_one_sync()) {
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
"cluster.b64 [%0];"
:
: "r"(bar_intptr));
}
}
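// Example (sketch): a producer/consumer handshake around the arrive. The
// barrier object and the consumer-side wait are assumptions for illustration.
//   __shared__ alignas(8) uint64_t mma_bar;  // initialized elsewhere
//   ... issue tcgen05 MMAs ...
//   tcgen05_mma_arrive(&mma_bar);            // one elected lane arrives
//   // consumer: wait on mma_bar before reading the accumulator results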
} // namespace tl
......@@ -47,7 +47,7 @@ public:
}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) {
auto storage_scope =
......@@ -58,7 +58,7 @@ public:
buf->dtype.bytes());
if (!new_shape.same_as(buf->shape)) {
ObjectPtr<BufferNode> new_buffer =
make_object<BufferNode>(*(buf.get()));
tvm::ffi::make_object<BufferNode>(*(buf.get()));
new_buffer->shape = std::move(new_shape);
buffer_remap_.Set(buf, Buffer(new_buffer));
return Buffer(new_buffer);
......@@ -73,7 +73,7 @@ public:
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store_node = GetRef<BufferStore>(op);
auto store_node = tvm::ffi::GetRef<BufferStore>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
......@@ -83,7 +83,7 @@ public:
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load_node = GetRef<BufferLoad>(op);
auto load_node = tvm::ffi::GetRef<BufferLoad>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
......@@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
AlignDynamicSharedMemoryAllocations);
});
}
} // namespace tl
} // namespace tvm
......@@ -46,13 +46,13 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::pipeline_exec_scope ||
op->attr_key == tir::attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
Stmt body = tvm::ffi::GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
......@@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions);
});
}
} // namespace tl
} // namespace tvm
......@@ -124,7 +124,9 @@ private:
}
auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
// In some degenerate warp-specialized patterns (e.g., producer-only),
// the consumer body may be absent. Handle gracefully by only annotating
// the producer side when consumer is missing.
auto dec_reg = nreg_[0].as<IntImmNode>()->value;
auto inc_reg = nreg_[1].as<IntImmNode>()->value;
......@@ -150,15 +152,20 @@ private:
producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts);
Stmt new_if_stmt;
if (consumer_body.defined()) {
Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
} else {
// No consumer branch; keep the if-then form.
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
}
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
return new_attr;
} else {
return StmtExprMutator::VisitStmt_(op);
......@@ -181,11 +188,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc);
});
}
} // namespace tl
} // namespace tvm
......@@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
Bind_(arg, value, arg_name, with_let);
}
void ArgBinder::BindArray(const Array<PrimExpr> &arg,
const Array<PrimExpr> &value,
void ArgBinder::BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name) {
ICHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
......@@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1);
Array<PrimExpr> conds;
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue =
......
......@@ -82,7 +82,8 @@ public:
* \param value The target expression value
* \param arg_name argument name.
*/
void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value,
void BindArray(const ffi::Array<PrimExpr> &arg,
const ffi::Array<PrimExpr> &value,
const std::string &arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
......@@ -149,7 +150,7 @@ public:
*/
const std::vector<Stmt> &init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */
const Map<Var, PrimExpr> &def_handle_dtype() const {
const ffi::Map<Var, PrimExpr> &def_handle_dtype() const {
return def_handle_dtype_;
}
......@@ -164,7 +165,7 @@ private:
/*! \brief Initialize nest */
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the definitions */
Map<Var, PrimExpr> def_handle_dtype_;
ffi::Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */
......