Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -30,23 +30,13 @@ enum class ReduceTypeEnum : uint8_t {
class ReduceTypeNode : public Object {
public:
int type{-1}; ///< Internal type identifier
static constexpr const char *_type_key = "tl.ReduceType";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
}
bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
return equal(type, other->type);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Type checking methods
bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
......@@ -61,9 +51,10 @@ public:
/// Wrapper class for reduction type with string-based construction
class ReduceType : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef,
ReduceTypeNode);
TVM_DLL ReduceType(std::string type) {
auto node = make_object<ReduceTypeNode>();
auto node = tvm::ffi::make_object<ReduceTypeNode>();
if (type == "sum") {
node->type = int(ReduceTypeEnum::kSum);
} else if (type == "abssum") {
......@@ -91,40 +82,27 @@ public:
class ReduceOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst; ///< Source and destination buffers
// Optional: keep the original regions used to construct this op
BufferRegion srcRegion_, dstRegion_;
int dim; ///< Dimension to reduce along
ReduceType type; ///< Type of reduction operation
bool clear; ///< Whether to clear destination before reduction
static constexpr const char *_type_key = "tl.ReduceOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ReduceOpNode>()
.def_ro("src", &ReduceOpNode::src)
.def_ro("dst", &ReduceOpNode::dst)
.def_ro("srcRegion", &ReduceOpNode::srcRegion_)
.def_ro("dstRegion", &ReduceOpNode::dstRegion_)
.def_ro("dim", &ReduceOpNode::dim)
.def_ro("type", &ReduceOpNode::type)
.def_ro("clear", &ReduceOpNode::clear);
}
bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(dim, other->dim) && equal(type, other->type) &&
equal(clear, other->clear);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(dim);
hash_reduce(type);
hash_reduce(clear);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/// Lower the operator to TIR statements
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
/// Infer memory layout for buffers
......@@ -145,7 +123,8 @@ private:
/// Wrapper class for reduction operations
class ReduceOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......@@ -156,8 +135,17 @@ public:
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
static constexpr const char *_type_key = "tl.CumSumOp";
TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CumSumOpNode>()
.def_ro("src", &CumSumOpNode::src)
.def_ro("dst", &CumSumOpNode::dst)
.def_ro("dim", &CumSumOpNode::dim)
.def_ro("reverse", &CumSumOpNode::reverse);
}
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
......@@ -169,7 +157,8 @@ public:
/// Wrapper class for cumulative sum operations
class CumSumOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -44,7 +44,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
PrimExpr extent = args[2 + i];
ranges.push_back(Range::FromMinExtent(min, extent));
}
ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
node->buffer_ = load->buffer;
node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
node->ranges_ = ranges;
......@@ -57,7 +57,7 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A new TileOperator that owns a copied RegionOpNode.
*/
TileOperator RegionOpNode::Clone() const {
auto op = make_object<RegionOpNode>(*this);
auto op = tvm::ffi::make_object<RegionOpNode>(*this);
return RegionOp(op);
}
......@@ -118,5 +118,7 @@ TIR_REGISTER_TL_OP(RegionOp, region)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -80,8 +80,8 @@ public:
Array<Range> ranges_;
int access_mask_;
static constexpr const char *_type_key = "tl.RegionOp";
TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
......@@ -101,25 +101,12 @@ public:
.def_ro("ranges", &RegionOpNode::ranges_)
.def_ro("access_mask", &RegionOpNode::access_mask_);
}
bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
equal(access_mask_, other->access_mask_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_);
hash_reduce(ranges_);
hash_reduce(access_mask_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
};
class RegionOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
RegionOpNode);
TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
......
#ifndef TVM_TL_OP_TCGEN5_META_H_
#define TVM_TL_OP_TCGEN5_META_H_
#include <cstdint>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <utility>
#include <vector>
namespace tvm {
namespace tl {
using runtime::DataType;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k,
DataType ab_dtype, DataType c_dtype,
bool a_is_k_major, bool b_is_k_major,
int scale_in_a, int scale_in_b) {
ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
ICHECK(atom_k == 16 || atom_k == 32)
<< "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k;
ICHECK(scale_in_a == 1 || scale_in_a == -1)
<< "scale_in_a must be +/-1 for TCGEN5MMA";
ICHECK(scale_in_b == 1 || scale_in_b == -1)
<< "scale_in_b must be +/-1 for TCGEN5MMA";
auto encode_dtype = [&](DataType dtype) -> uint32_t {
if (dtype.is_float16()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_bfloat16()) {
return static_cast<uint32_t>(1);
} else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
dtype.is_float8_e4m3()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
return static_cast<uint32_t>(1);
}
LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype;
return 0u;
};
uint32_t a_format = encode_dtype(ab_dtype);
uint32_t b_format = a_format;
uint32_t c_format = 0;
if (c_dtype.is_float16()) {
c_format = 0;
} else if (c_dtype.is_float()) {
c_format = 1;
} else if (c_dtype.is_int()) {
c_format = 2;
} else {
LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: "
<< c_dtype;
}
auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
return (value & mask) << start;
};
uint32_t desc = 0;
desc |= set_bits(0, 0, 2); // sparse_id2
desc |= set_bits(0, 2, 1); // sparse_flag
desc |= set_bits(0, 3, 1); // saturate
desc |= set_bits(c_format, 4, 2);
desc |= set_bits(a_format, 7, 3);
desc |= set_bits(b_format, 10, 3);
uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
desc |= set_bits(a_neg, 13, 1);
desc |= set_bits(b_neg, 14, 1);
uint32_t a_major = a_is_k_major ? 0u : 1u;
uint32_t b_major = b_is_k_major ? 0u : 1u;
desc |= set_bits(a_major, 15, 1);
desc |= set_bits(b_major, 16, 1);
uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);
desc |= set_bits(n_dim, 17, 6);
desc |= set_bits(0, 23, 1);
desc |= set_bits(m_dim, 24, 5);
desc |= set_bits(0, 29, 1);
uint32_t max_shift = 0u;
desc |= set_bits(max_shift, 30, 2);
return desc;
}
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_TCGEN5_META_H_
......@@ -89,7 +89,7 @@ struct TensorMapArgs {
};
// set device api
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
Any *ret) {
......@@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
}
*ret = static_cast<int>(result);
});
});
}
struct TensorMapIm2ColArgs {
CUtensorMap *map;
......@@ -180,7 +180,7 @@ struct TensorMapIm2ColArgs {
}
};
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
......@@ -197,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
}
*ret = static_cast<int>(result);
});
});
}
#endif // (CUDA_MAJOR_VERSION >= 12)
......
#pragma once
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
namespace tvm {
using ffi::Array;
using ffi::Function;
using ffi::Map;
using ffi::Optional;
using ffi::String;
} // namespace tvm
......@@ -29,6 +29,7 @@
#include <unordered_set>
#include <utility>
#include "../support/ffi_aliases.h"
#include "support/str_escape.h"
#include "target/build_common.h"
#include "target/source/codegen_params.h"
......@@ -54,8 +55,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
}
void CodeGenTileLangCPP::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n";
}
void CodeGenTileLangCPP::DefineModuleName() {
......@@ -256,8 +256,8 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -73,10 +73,10 @@ public:
void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
void VisitStmt_(const AllocateNode *op) final; // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<Type> &arg_types,
void GenerateForwardFunctionDeclarations(ffi::String global_symbol,
const ffi::Array<Type> &arg_types,
const Type &ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }
ffi::Array<ffi::String> GetFunctionNames() { return function_names_; }
private:
/* \brief Internal structure to store information about function calls */
......@@ -92,7 +92,7 @@ private:
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
Array<String> function_names_;
ffi::Array<ffi::String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
......
......@@ -20,6 +20,7 @@
namespace tvm {
namespace codegen {
using namespace tvm::tl::codegen;
using namespace ffi;
struct CUDAMath {
std::string operator()(DataType t, std::string name) const {
......@@ -259,6 +260,21 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
if (need_mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma.h>\n";
}
if (need_wgmma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/wgmma.h>\n";
}
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
}
if (need_mma_sm70_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma_sm70.h>\n";
}
if (need_tcgen05_common_h_) {
decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
}
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
......@@ -919,6 +935,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__half22float2(*((half2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// half8 -> float8
PrintIndent();
stream << "((float2*)(&" << sret << "))[0] = "
<< "__half22float2(*(half2*)(&(" << src << ")));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[1] = "
<< "__half22float2(*((half2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[2] = "
<< "__half22float2(*((half2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[3] = "
<< "__half22float2(*((half2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
} else if (from_ty.is_float() && target_ty.is_float16()) {
// Use __float22half2_rn for vectorized conversion (float2 -> half2)
......@@ -939,6 +971,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> half8
PrintIndent();
stream << "((half2*)(&" << sret << "))[0] = "
<< "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[1] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[2] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "((half2*)(&" << sret << "))[3] = "
<< "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
}
......@@ -965,6 +1013,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// bfloat162x4 -> float8
PrintIndent();
stream << "((float2*)(&" << sret << "))[0] = "
<< "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
<< src << ")));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[1] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+1));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[2] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+2));\n";
PrintIndent();
stream << "((float2*)(&" << sret << "))[3] = "
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
<< src << "))+3));\n";
os << sret;
return;
}
} else if (from_ty.is_float() && target_ty.is_bfloat16()) {
// Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
......@@ -985,6 +1053,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> bfloat162x4
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
<< "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
PrintIndent();
stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
<< "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
os << sret;
return;
}
}
......@@ -1017,6 +1101,36 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// float8 -> fp8x8
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+2), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+3), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
}
}
......@@ -1034,6 +1148,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
os << sret;
}
void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard min functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard max functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args,
bool skip_first_arg,
......@@ -1132,7 +1292,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope == "local.var" || scope == "local.descriptor") {
if (scope == "local.var" || scope.find("local.descriptor") == 0) {
os << vid;
return os.str();
}
......@@ -1452,6 +1612,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
<< ">();\n";
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
ICHECK_EQ(op->args.size(), 4U);
std::string dtype = Downcast<StringImm>(op->args[0])->value;
std::string data_ptr = this->PrintExpr(op->args[1]);
std::string offset = this->PrintExpr(op->args[2]);
std::string num_regs = this->PrintExpr(op->args[3]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype);
std::string cast_type = "uint32_t";
if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 ||
dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) {
cast_type = "float";
}
this->PrintIndent();
this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type
<< "*>(" << data_ptr << " + " << offset << "), " << num_regs
<< ");\n";
} else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
......@@ -1563,14 +1739,124 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op =
op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_mma_instruction_h_ = true;
this->PrintIndent();
this->stream << asm_code;
std::string mma_call =
"tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
// TODO(lei): Type Workaround for TF32, should be removed when
// we introduced tfloat32_t in the frontend.
std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum);
if (AType == "tl::DataType::kFloat32") {
AType = "tl::DataType::kTensorFloat32";
}
std::string BType = tl::codegen::ptx::DTypeEnumToString(dtype_b_enum);
if (BType == "tl::DataType::kFloat32") {
BType = "tl::DataType::kTensorFloat32";
}
std::string ARegType = tl::codegen::GetMMARegisterType(dtype_a_enum);
if (ARegType == "float") {
ARegType = "uint32_t";
}
std::string BRegType = tl::codegen::GetMMARegisterType(dtype_b_enum);
if (BRegType == "float") {
BRegType = "uint32_t";
}
replacer.register_rule("(AType)", AType);
replacer.register_rule("(BType)", BType);
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)", ARegType);
replacer.register_rule("(BRegType)", BRegType);
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(tl::ptx_mma_sm70())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16
// arg 4: B precision: fp16
// arg 5: C precision: fp16, fp32
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
ICHECK_EQ(op->args.size(), 12U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_mma_sm70_instruction_h_ = true;
this->PrintIndent();
std::string mma_call =
"tl::mma_sync_sm70<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
......@@ -1636,27 +1922,32 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string B_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
bool scale_out = Downcast<Bool>(op->args[12])->value;
std::string scale_out = this->PrintExpr(op->args[12]);
bool scale_in_a = Downcast<Bool>(op->args[13])->value;
bool scale_in_b = Downcast<Bool>(op->args[14])->value;
const bool a_is_shared = true;
this->PrintIndent();
std::string asm_code = PrintWGMMAAssembly(
shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
scale_in_b, a_is_shared, "", "", "", false);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_wgmma_instruction_h_ = true;
std::string wgmma_asm_code =
"tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
// replace patterns
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(A_dtype));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(B_dtype));
std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype);
if (AType == "tl::DataType::kFloat32") {
AType = "tl::DataType::kTensorFloat32";
}
std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype);
if (BType == "tl::DataType::kFloat32") {
BType = "tl::DataType::kTensorFloat32";
}
replacer.register_rule("(AType)", AType);
replacer.register_rule("(BType)", BType);
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(C_dtype));
replacer.register_rule("(M)", std::to_string(m));
......@@ -1671,45 +1962,184 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref + " + " + c_offset);
replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
replacer.register_rule("(scale_out)", scale_out);
wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
this->stream << wgmma_asm_code;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
// arg 0: dtype
// arg 1: shape
// arg 2: A_layout
// arg 3: B_layout
// arg 4: A_dtype
// arg 5: B_dtype
// arg 6: C_dtype
// arg 7: multiplicand_a
// arg 8: multiplicand_b
// arg 0: shape
// arg 1: B_layout
// arg 2: A_dtype
// arg 3: B_dtype
// arg 4: C_dtype
// arg 5: multiplicand_a
// arg 6: multiplicand_a offset
// arg 7: multiplicand_b descriptor
// arg 8: multiplicand_b offset
// arg 9: accumulator
// arg 10: saturate
ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args;
// arg 10: accumulator offset
// arg 11: scale_out
// arg 12: scale_in_a
// arg 13: scale_in_b
ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args;
std::string shape = Downcast<StringImm>(op->args[0])->value;
bool A_layout = Downcast<Bool>(op->args[1])->value;
bool B_layout = Downcast<Bool>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string A_offset = this->PrintExpr(op->args[7]);
std::string b_desc = this->PrintExpr(op->args[8]);
std::string B_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
bool scale_out = Downcast<Bool>(op->args[12])->value;
bool scale_in_a = Downcast<Bool>(op->args[13])->value;
bool scale_in_b = Downcast<Bool>(op->args[14])->value;
bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
std::string a_ref = this->PrintExpr(op->args[5]);
std::string A_offset = this->PrintExpr(op->args[6]);
std::string b_desc = this->PrintExpr(op->args[7]);
std::string B_offset = this->PrintExpr(op->args[8]);
std::string c_ref = this->PrintExpr(op->args[9]);
std::string c_offset = this->PrintExpr(op->args[10]);
std::string scale_out = this->PrintExpr(op->args[11]);
bool scale_in_a = Downcast<Bool>(op->args[12])->value;
bool scale_in_b = Downcast<Bool>(op->args[13])->value;
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
const bool a_is_shared = false;
need_wgmma_instruction_h_ = true;
this->PrintIndent();
std::string asm_code = PrintWGMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset,
b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
a_is_shared, "", "", "", false);
this->stream << asm_code;
std::string wgmma_call =
"tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const "
"uint32_t*>((A_ptr) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), "
"reinterpret_cast<uint32_t*>((C_ptr) + (C_offset)), "
"(scale_out));\n";
tl::codegen::Replacer replacer;
std::string AType = tl::codegen::ptx::DTypeEnumToString(A_dtype);
if (AType == "tl::DataType::kFloat32") {
AType = "tl::DataType::kTensorFloat32";
}
std::string BType = tl::codegen::ptx::DTypeEnumToString(B_dtype);
if (BType == "tl::DataType::kFloat32") {
BType = "tl::DataType::kTensorFloat32";
}
replacer.register_rule("(AType)", AType);
replacer.register_rule("(BType)", BType);
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(tnspA)", "false");
replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out);
wgmma_call = replacer.rewrite(wgmma_call);
this->stream << wgmma_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
ICHECK_EQ(op->args.size(), 14U)
<< "ptx_tcgen05_mma_ss args is " << op->args;
std::string C_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_desc = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
bool enable_ws = Downcast<Bool>(op->args[13])->value;
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::(tcgen05_name)<(CType)>(uint64_t((desc_a) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(desc_a)", a_desc);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(tcgen05_name)",
enable_ws ? "tcgen05mma_ws_ss" : "tcgen05mma_ss");
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
// TS: A from TMEM, B from SMEM (desc)
ICHECK_EQ(op->args.size(), 13U)
<< "ptx_tcgen05_mma_ts args is " << op->args;
std::string kind_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_ref = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::tcgen05mma_ts<(CType)>( (*reinterpret_cast<uint32_t*>((A))) + "
"(A_offset), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_enum));
replacer.register_rule("(A)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
need_tcgen05_common_h_ = true;
this->PrintIndent();
this->stream << "tl::tcgen05_mma_arrive(" << this->PrintExpr(op->args[0])
<< ");\n";
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
......@@ -2021,8 +2451,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
ICHECK(op->args.size() == 5)
<< "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
......@@ -2030,8 +2460,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::get_lane_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
......@@ -2069,19 +2499,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << ")";
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) {
} else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
ICHECK(op->args.size() == 5)
<< "tl_initialize_descriptor expects 5 arguments but got "
<< "tl_initialize_wgmma_descriptor expects 5 arguments but got "
<< op->args.size();
auto descriptor = op->args[0];
auto start_address = op->args[1];
auto layout_type = op->args[2];
auto leading_byte_offset = op->args[3];
auto stride_byte_offset = op->args[4];
os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", "
os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", "
<< PrintExpr(leading_byte_offset) << ", "
<< PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ")";
} else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
ICHECK(op->args.size() == 7)
<< "tl_initialize_tcgen05_descriptor expects 7 arguments but got "
<< op->args.size();
auto descriptor = op->args[0];
auto start_address = op->args[1];
auto leading_byte_offset = op->args[2];
auto stride_byte_offset = op->args[3];
auto base_offset = op->args[4];
auto leading_abs = op->args[5];
auto swizzle_mode = op->args[6];
os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ", " << PrintExpr(leading_byte_offset)
<< ", " << PrintExpr(stride_byte_offset) << ", "
<< PrintExpr(base_offset) << ", " << PrintExpr(leading_abs) << ", "
<< PrintExpr(swizzle_mode) << ")";
} else if (op->op.same_as(tl::increase_descriptor_offset())) {
ICHECK(op->args.size() == 2)
<< "tl_increase_descriptor_offset expects 2 arguments but got "
......@@ -2232,8 +2678,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
<< "Accumulator only support half, float and int type for now";
}
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else if (scope == "local.descriptor") {
} else if (scope == "local.descriptor.wgmma") {
stream << "tl::GmmaDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_smem") {
stream << "tl::Tcgen05SMemDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_instr") {
stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n";
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
......@@ -2275,7 +2725,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
init = user_init;
}
stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
} else if (scope != "local.descriptor") {
} else if (scope.find("local.descriptor") != 0) {
ICHECK(false) << "Unsupported scope: " << scope;
}
}
......@@ -2297,6 +2747,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = PrintExpr(call->args[0]);
this->PrintIndent();
stream << "device_assert(" << cond << ");\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = PrintExpr(call->args[0]);
std::string msg_expr = PrintExpr(call->args[1]);
this->PrintIndent();
stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
} else {
CodeGenC::VisitStmt_(op);
}
......@@ -2304,8 +2764,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
<< lanes << " lanes is not allowed.";
CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
<< " with " << lanes << " lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
os << "(";
......@@ -2540,12 +3000,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
os << '(' << std::hexfloat << op->value << 'f';
os << "/*" << std::scientific << op->value << "*/";
os << ')';
// Type code is kBFloat/kFloat16
// which is indeed CUTLASS supported types currently
if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::infinity()";
} else if (std::isnan(op->value)) {
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::quiet_NaN()";
} else {
p->PrintType(op->dtype, temp);
temp << '(' << std::hexfloat << op->value << 'f';
temp << "/*" << std::scientific << op->value << "*/";
temp << ')';
}
p->MarkConst(temp.str());
os << temp.str();
return;
}
// Type code is kFloat8_e5m2 or kE4M4Float
......@@ -2556,7 +3033,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << ')';
return;
}
// Type code is kFloat
// Type code is kFloat64/kFloat32 (kFloat16 is handled above)
switch (op->dtype.bits()) {
case 64:
case 32: {
......@@ -2580,13 +3057,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
......@@ -2807,7 +3277,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -51,6 +51,8 @@ public:
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitExpr_(const MinNode *op, std::ostream &os) final;
void VisitExpr_(const MaxNode *op, std::ostream &os) final;
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
......@@ -58,14 +60,14 @@ public:
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
void PrintFunctionSignature(const String &function_name, const PrimFunc &func,
std::ostream &os);
void PrintFunctionSignature(const ffi::String &function_name,
const PrimFunc &func, std::ostream &os);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
void PrintCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
......@@ -106,6 +108,16 @@ private:
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need tl mma instruction header
bool need_mma_instruction_h_{false};
// whether need tl wgmma instruction header
bool need_wgmma_instruction_h_{false};
// whether need tl tcgen05mma instruction header
bool need_tcgen05mma_instruction_h_{false};
// whether need tl mma_sm70 instruction header
bool need_mma_sm70_instruction_h_{false};
// whether need tcgen_05 common header
bool need_tcgen05_common_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h
......
......@@ -929,7 +929,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float32", "float"},
{"float64", "double"},
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"bfloat16x4", "bfloat16x4_vec"},
{"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"},
......@@ -1025,8 +1025,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
} else if (op->op.same_as(tl::loop_break())) {
......@@ -1375,7 +1375,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
ICHECK(global_symbol.has_value())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
......
......@@ -56,8 +56,8 @@ public:
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
void PrintCallExtern(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.cc
*/
#include "codegen_webgpu.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "arith/pattern_match.h"
#include "runtime/meta_data.h"
#include "runtime/thread_storage_scope.h"
#include "target/build_common.h"
namespace tvm {
namespace codegen {
// WebGPU Info
struct WebGPUWorkGroupInfo {
int workgroup_size[3] = {1, 1, 1};
// whether we have ref to block index z is used.
bool has_block_index_z{false};
// set of handles that have write access
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> write_access_set;
};
class WebGPUWorkgroupInfoCollector : public StmtExprVisitor {
public:
static WebGPUWorkGroupInfo Collect(const Stmt &stmt) {
WebGPUWorkgroupInfoCollector collector;
collector(stmt);
return collector.info_;
}
private:
void VisitExpr_(const VarNode *op) final {
StmtExprVisitor::VisitExpr_(op);
Var buffer_var = GetRef<Var>(op);
if (buffer_var.dtype().is_handle()) {
info_.write_access_set.insert(buffer_var);
}
}
void VisitStmt_(const BufferStoreNode *op) final {
StmtExprVisitor::VisitStmt_(op);
info_.write_access_set.insert(op->buffer->data);
}
void VisitStmt_(const AttrStmtNode *op) final {
// record workgroup size
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (!iv->thread_tag.empty()) {
runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
if (ts.rank == 1) {
ICHECK_GE(ts.dim_index, 0)
<< "vthread should have been optimized out by here";
ICHECK_LT(ts.dim_index, 3);
auto *sizeptr = op->value.as<tir::IntImmNode>();
ICHECK(sizeptr) << "CodeGenTileLangWebGPU: only allows constant "
"thread group size "
<< " get " << op->value;
info_.workgroup_size[ts.dim_index] =
static_cast<uint32_t>(sizeptr->value);
} else if (ts.rank == 0) {
if (ts.dim_index == 2) {
info_.has_block_index_z = true;
}
}
}
}
// normal operation
StmtExprVisitor::VisitStmt_(op);
}
WebGPUWorkGroupInfo info_;
};
std::string CodeGenTileLangWebGPU::Finish() {
// Using f16 requires enable directive
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
// WebGPU WGSL doesn't support #include.
// We must explicitly include all the templates here.
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() +
stream.str();
}
void CodeGenTileLangWebGPU::InitFuncState(const PrimFunc &f) {
CodeGenC::InitFuncState(f);
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
}
CodeGenTileLangWebGPU::CodeGenTileLangWebGPU(Target target) : target_(target) {}
runtime::FunctionInfo
CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
name_supply_->ReserveName("var");
name_supply_->ReserveName("let");
name_supply_->ReserveName("const");
// skip the first underscore, so SSA variable starts from
name_supply_->FreshName("v_");
// Setup the thread group info.
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim");
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
"to have the global_symbol attribute";
header_stream << "//----------------------------------------\n"
<< "// Function: " << global_symbol.value() << "\n"
<< "//----------------------------------------\n";
runtime::FunctionInfo func_info;
func_info.name = global_symbol.value();
WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);
std::vector<Var> pod_args;
int num_buffer = 0;
// add param_access modes info to launch params
std::ostringstream os_param_access;
os_param_access << "paramWriteAccess:[";
// setup buffer argumemts
for (Var arg : f->params) {
DataType t = arg.dtype();
func_info.arg_types.push_back(t);
if (t.is_handle()) {
auto *ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the CodeGenTileLangWebGPU must "
"have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
auto *prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim) << "All handles passed to the CodeGenTileLangWebGPU must "
"have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::Bool()) {
// We need a physically addressable buffer type to support boolean
// tensors. The loaded byte is cast to bool inside the LoadNode visitor
// below.
value_storage_type =
boolean_storage_type_.with_lanes(value_storage_type.lanes());
}
std::string vid = AllocVarID(arg.get());
std::string access_mode;
if (num_buffer != 0) {
os_param_access << ",";
}
if (skip_readonly_decl || info.write_access_set.count(arg)) {
access_mode = "read_write";
os_param_access << "1";
} else {
access_mode = "read";
os_param_access << "0";
}
// add extra access mode info to launch params
this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
<< "var<storage, " << access_mode << "> " << vid
<< " : array<";
this->PrintType(value_storage_type, this->decl_stream);
this->decl_stream << ">;\n";
} else {
pod_args.push_back(arg);
}
}
// Store all pod arguments in a single buffer of int32
// do bitcast to change to other data types
// always pass gridDimX in to get around of the 65535 gridDim
// restrictions in some platforms
std::string type_pod_args = name_supply_->FreshName("PODArgs");
std::string val_pod_args = name_supply_->FreshName("podArgs");
std::string packGridDimX = name_supply_->FreshName("packGridDimX");
this->decl_stream << "\nstruct " << type_pod_args << " {\n";
for (size_t i = 0; i < pod_args.size(); ++i) {
const Var &v = pod_args[i];
ICHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
if (v.dtype() == DataType::Int(32)) {
this->decl_stream << " " << vid << ": i32";
} else if (v.dtype() == DataType::UInt(32)) {
this->decl_stream << " " << vid << ": u32";
} else if (v.dtype() == DataType::Float(32)) {
this->decl_stream << " " << vid << ": f32";
} else {
LOG(FATAL) << "Do not support pod argument type " << v.dtype();
}
this->decl_stream << ",\n";
// value ref
std::ostringstream vref;
vref << val_pod_args << "." << vid;
var_idmap_[v.get()] = vref.str();
}
this->decl_stream << " " << packGridDimX << ": u32\n}\n";
this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
<< "var<uniform> " << val_pod_args << " : " << type_pod_args
<< ";\n\n";
// setup thread tags and param access in launch param tags;
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto &thread_tag : opt.value()) {
func_info.launch_param_tags.push_back(thread_tag);
}
}
os_param_access << "]";
func_info.launch_param_tags.push_back(os_param_access.str());
ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to "
"accommodate large blockIdx.x";
// annotate workgroup
this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
<< info.workgroup_size[1] << ", " << info.workgroup_size[2]
<< ")\n";
// add to alloc buffer type.
// Function header.
this->stream << "fn " << func_info.name << "(\n"
<< " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
<< " @builtin(num_workgroups) gridDim : vec3<u32>,\n"
<< " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
<< ") {\n";
// skip out of bound grids
this->stream << " if (blockIdx.z * gridDim.x + blockIdx.x > " // NOLINT(*)
<< val_pod_args << "." << packGridDimX << ") { return; }\n";
// the function scope.
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
return func_info;
}
void CodeGenTileLangWebGPU::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
std::ostringstream os;
PrintType(iv->var.dtype(), os);
if (iv->thread_tag == "blockIdx.x") {
// WebGPU have restriction to limit the maximum size of blockId.x to be
// 65535 We allow runtime to spread the load out to blockIdx.z so it can be
// a large number.
os << "(blockIdx.z * gridDim.x + blockIdx.x)";
std::string tidx = os.str();
std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype());
var_idmap_[iv->var.get()] = aggregated_bidx;
} else {
os << "(" << iv->thread_tag << ")";
std::string tidx = os.str();
this->MarkConst(tidx);
var_idmap_[iv->var.get()] = tidx;
}
}
void CodeGenTileLangWebGPU::PrintType(DataType t,
std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
LOG(FATAL) << "Cannot print handle type in WebGPU";
}
if (t.is_void()) {
os << "void";
return;
}
if (t == DataType::Bool()) {
os << "bool";
return;
}
if (lanes != 1) {
// ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenTileLangWebGPU: only allows
// vector with lanes in {2, 3, 4} " << " while lanes is " << lanes;
os << "vec" << lanes << "<";
}
if (t.is_float()) {
ICHECK(t.bits() == 16 || t.bits() == 32)
<< "CodeGenTileLangWebGPU: only support f16 or f32";
if (t.bits() == 16) {
// Using f16 requires enable directive
enable_fp16_ = true;
}
os << "f" << t.bits();
} else if (t.is_uint()) {
ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support u64";
os << "u" << t.bits();
} else if (t.is_int()) {
ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support i64";
os << "i" << t.bits();
} else {
LOG(FATAL) << "CodeGenTileLangWebGPU: Cannot convert type " << t
<< " to WebGPU type";
}
if (lanes != 1) {
os << ">";
}
}
void CodeGenTileLangWebGPU::PrintStorageSync(const CallNode *op) {
const std::string &sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "global") {
LOG(FATAL) << "global barrier not supported";
}
}
void CodeGenTileLangWebGPU::PrintSSAAssign(const std::string &target,
const std::string &src,
DataType type) {
stream << "let " << target << " : ";
PrintType(type, stream);
stream << " = " << src << ";\n";
}
void CodeGenTileLangWebGPU::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << v;
}
os << ')';
}
PrimExpr CodeGenTileLangWebGPU::EnforceU32(PrimExpr value) {
return cast(DataType::UInt(32, value.dtype().lanes()), value);
}
void CodeGenTileLangWebGPU::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
if (op->op.same_as(builtin::reinterpret())) {
// generate bitcast<TYPE>(ARG)
os << "bitcast<";
this->PrintType(op->dtype, os);
os << ">(";
this->PrintExpr(op->args[0], os);
os << ")";
} else if (op->op.same_as(builtin::shift_right())) {
os << '(';
this->PrintExpr(op->args[0], os);
os << ">>";
// WebGPU requires shift bits to be u32.
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::shift_left())) {
os << '(';
this->PrintExpr(op->args[0], os);
os << "<<";
// WebGPU requires shift bits to be u32.
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::if_then_else())) {
// conditional that skips eval if cond evals to false
std::string result = name_supply_->FreshName("condval");
std::string cond = PrintExpr(op->args[0]);
this->PrintIndent();
this->stream << "var " << result << " : ";
PrintType(op->dtype, this->stream);
this->stream << ";\n";
this->PrintIndent();
this->stream << "if (" << cond << ") {\n";
{
int then_scope = this->BeginScope();
std::string true_val = PrintExpr(op->args[1]);
this->PrintIndent();
this->stream << result << " = " << true_val << ";\n} else {\n";
this->EndScope(then_scope);
}
{
int else_scope = this->BeginScope();
std::string false_val = PrintExpr(op->args[2]);
this->PrintIndent();
this->stream << result << " = " << false_val << ";\n}\n";
this->EndScope(else_scope);
}
os << result;
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangWebGPU::VisitExpr_(const CastNode *op,
std::ostream &os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << PrintExpr(op->value) << ")";
}
void CodeGenTileLangWebGPU::VisitExpr_(const SelectNode *op,
std::ostream &os) { // NOLINT(*)
os << "select(" << PrintExpr(op->false_value) << ", "
<< PrintExpr(op->true_value) << ", " << PrintExpr(op->condition) << ")";
}
void CodeGenTileLangWebGPU::VisitExpr_(const IntImmNode *op,
std::ostream &os) { // NOLINT(*)
if (op->dtype.bits() == 32) {
std::ostringstream temp;
if (op->dtype.is_int()) {
temp << op->value << "i";
} else {
ICHECK(op->dtype.is_uint());
temp << op->value << "u";
}
this->MarkConst(temp.str());
os << temp.str();
} else {
this->PrintType(op->dtype, os);
os << "(" << op->value << ")";
}
}
void CodeGenTileLangWebGPU::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
std::ostringstream temp;
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) {
temp << 'f';
} else if (op->dtype.bits() == 16) {
// Using f16 requires enable directive
enable_fp16_ = true;
temp << 'h';
} else {
LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits();
}
MarkConst(temp.str());
os << temp.str();
}
void CodeGenTileLangWebGPU::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
// NOTE: direct impl of load/store for correctness
// Each printing stmt must stand on their own after all preprocessing steps
// to ensure correctness in the case of nested-expression
// do not try to lift common printings from each case
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
int lanes = op->dtype.lanes();
std::string buffer_vid = GetVarID(buffer_var.get());
if (value_dtype.lanes() == element_dtype.lanes()) {
// Direct buffer loading
// Special handle bool loading
if (value_dtype == DataType::Bool()) {
this->PrintType(value_dtype, os);
os << "(";
} else {
ICHECK(value_dtype == element_dtype);
}
ICHECK_EQ(index.dtype().lanes(), 1);
os << buffer_vid << "[" << this->PrintExpr(index) << "]";
// Special handle bool loading
if (value_dtype == DataType::Bool()) {
os << ")";
}
} else {
// Vector load from scalar buffer
ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
ICHECK(value_dtype.element_of() == element_dtype)
<< "WebGPU vector loading requires base type to match";
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
// vec3<f32>(buf[base + 0], buf[base + 1], buf[base + 2]);
std::string base_vid =
SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << buffer_vid << "[" << base_vid << " + " << i << "]";
}
os << ")";
} else {
// vec3<f32>(buf[index[0]], buf[index[1]], buf[index[2]]);
std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << buffer_vid << "[" << index_vid << "[" << i << "]]";
}
os << ")";
}
}
}
void CodeGenTileLangWebGPU::VisitStmt_(const LetStmtNode *op) {
// use ssa form.
if (print_ssa_form_) {
std::string value = PrintExpr(op->value);
ICHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
std::string value = PrintExpr(op->value);
this->stream << "let " << AllocVarID(op->var.get()) << " : ";
PrintType(op->var.dtype(), this->stream);
this->stream << " = " << value << ";\n";
}
PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const BufferStoreNode *op) {
CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
std::string buffer_vid = GetVarID(buffer_var.get());
if (value_dtype.lanes() == element_dtype.lanes()) {
// must execute print expr first
// so we won't have recursive append to stream
std::string index_vid = PrintExpr(index);
std::string value_vid = PrintExpr(op->value);
// now print the assignment line.
this->PrintIndent();
stream << buffer_vid << "[" << index_vid << "] = ";
// special explicit conversion of bool
if (value_dtype == DataType::Bool()) {
PrintType(element_dtype, stream);
stream << "(";
} else {
ICHECK(value_dtype == element_dtype);
}
stream << value_vid;
// Special handle bool store
if (value_dtype == DataType::Bool()) {
stream << ")";
}
stream << ";\n";
} else {
// Vector store into scalar buffer
ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
ICHECK(value_dtype.element_of() == element_dtype)
<< "WebGPU vector stire requires base type to match";
std::string value_vid = PrintExpr(op->value);
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) {
// buf[base + 0] = value[0]
// buf[base + 1] = value[1]
std::string base_vid =
SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
stream << buffer_vid << "[" << base_vid << " + " << i
<< "] = " << value_vid << "[" << i << "];\n";
}
} else {
// buf[index[0]] = value[0]
// buf[index[1]] = value[1]
std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
stream << buffer_vid << "[" << index_vid << "[" << i
<< "]] = " << value_vid << "[" << i << "];\n";
}
}
}
}
void CodeGenTileLangWebGPU::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kShared) {
this->decl_stream << "var<workgroup> " << vid << " : array<";
PrintType(op->dtype, this->decl_stream);
this->decl_stream << ", " << constant_size << ">;\n";
} else if (storage_scope.rank == runtime::StorageRank::kLocal) {
// TODO(Charlie): These code would cause non-uniformity as it introduces
// variables in module scope rather than function scope; but it was included
// for some unknown reasons; kept for now. this->decl_stream <<
// "var<private> " << vid << " : array<"; PrintType(op->dtype,
// this->decl_stream); this->decl_stream << ", " << constant_size << ">;\n";
this->PrintIndent();
this->stream << "var " << vid << " : array<";
PrintType(op->dtype, this->stream);
this->stream << ", " << constant_size << ">;\n";
} else {
LOG(FATAL) << "WebGPU: Do not support storage scope: "
<< storage_scope.to_string();
}
this->PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const ForNode *op) {
std::string extent = PrintExpr(op->extent);
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
void CodeGenTileLangWebGPU::VisitStmt_(const AssertStmtNode *op) {
// skip assert
PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const AllocateConstNode *op) {
LOG(FATAL) << "WebGPU: do not support alloc const";
}
void CodeGenTileLangWebGPU::VisitStmt_(const WhileNode *op) {
PrintIndent();
stream << "while (true) {\n";
int while_scope = BeginScope();
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) { break; }\n";
PrintStmt(op->body);
this->EndScope(while_scope);
PrintIndent();
stream << "}\n";
}
//-------------------------------------------------
// WebGPUSourceModule to enable export
//-------------------------------------------------
class WebGPUSourceModuleNode final : public runtime::ModuleNode {
public:
explicit WebGPUSourceModuleNode(
std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, runtime::FunctionInfo> fmap)
: smap_(smap), fmap_(fmap) {}
const char *type_key() const final { return "webgpu"; }
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final {
return runtime::ModulePropertyMask::kBinarySerializable;
}
ffi::Function GetFunction(const String &name,
const ObjectPtr<Object> &sptr_to_self) final {
LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run "
"through tvmjs";
return ffi::Function(nullptr);
}
void SaveToBinary(dmlc::Stream *stream) final {
stream->Write(fmap_);
stream->Write(smap_);
}
String GetSource(const String &format) final {
if (format == "func_info") {
std::ostringstream stream;
dmlc::JSONWriter(&stream).Write(fmap_);
return stream.str();
} else {
std::ostringstream os;
for (const auto &kv : smap_) {
os << kv.second;
}
return os.str();
}
}
private:
// function shader code table.
std::unordered_map<std::string, std::string> smap_;
// function information table.
std::unordered_map<std::string, runtime::FunctionInfo> fmap_;
};
//-------------------------------------------------
// Build logic.
//-------------------------------------------------
runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) {
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
bool output_ssa = false;
bool skip_readonly_decl = false;
std::unordered_map<std::string, std::string> smap;
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
// narrow all i64 to i32
mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod));
for (auto kv : mod->functions) {
CodeGenTileLangWebGPU cg(target);
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangWebGPU: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenTileLangWebGPU: expect calling_conv equals "
"CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
"to have the global_symbol attribute";
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
std::string code = cg.Finish();
smap[f_name] = code;
}
auto n = make_object<WebGPUSourceModuleNode>(smap, fmap);
return runtime::Module(n);
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_webgpu",
[](IRModule mod, Target target) {
return BuildTileLangWebGPU(mod, target);
});
});
} // namespace codegen
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#include <tvm/target/codegen.h>
#include <string>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
/*!
* \brief WebGPU code generator.
*
* Note WGSL have a different syntax from normal C.
* We only leverage the C for expression generation and
* write most of the language generations.
*/
class CodeGenTileLangWebGPU final : public CodeGenC {
public:
explicit CodeGenTileLangWebGPU(Target target);
// overrides
std::string Finish() final;
using CodeGenC::AddFunction;
runtime::FunctionInfo AddFunction(const PrimFunc &f,
bool skip_readonly_decl); // NOLINT(*)
void InitFuncState(const PrimFunc &f) final;
void PrintStorageSync(const CallNode *op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
// assignment printing
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType type) final;
// overload visitor
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) final; // NOLINT(*)
// stmt printing
void VisitStmt_(const LetStmtNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AssertStmtNode *op) final;
void VisitStmt_(const AllocateConstNode *op) final;
void VisitStmt_(const WhileNode *op) final;
private:
/*!
* \brief Enforce value to be U32.
*/
static PrimExpr EnforceU32(PrimExpr value);
/*!
* \brief Storage type of bool values.
*/
DataType boolean_storage_type_{DataType::Int(8)};
// whether enable fp16
bool enable_fp16_{false};
/*! \brief the header stream for function label and enable directive if any,
* goes before any other declaration */
std::ostringstream header_stream;
Target target_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
......@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
......
......@@ -5,6 +5,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
#include "target/intrin_rule.h"
namespace tvm {
......
......@@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) {
return DataType::kInt64;
} else if (str == "uint64" || str == ".u64") {
return DataType::kUInt64;
} else if (str == "e4m3" || str == ".e4m3") {
} else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") {
return DataType::kFloat8_e4m3;
} else if (str == "e5m2" || str == ".e5m2") {
} else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") {
return DataType::kFloat8_e5m2;
} else if (str == "float16" || str == "fp16" || str == ".f16") {
return DataType::kFloat16;
......@@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) {
return predicated_asm_code;
}
std::string GetMMARegisterType(const ptx::DataType &dtype) {
switch (dtype) {
case ptx::DataType::kInt32:
return "unsigned";
case ptx::DataType::kUInt32:
return "unsigned";
case ptx::DataType::kFloat32:
return "float";
case ptx::DataType::kFloat64:
return "double";
default:
return "unsigned";
}
}
} // namespace codegen
} // namespace tvm::tl
......@@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
*/
std::string PrintWaitBarrierAsm(const std::string &barrier);
/*!
* \brief Return the register-level C++ type used by MMA fragments.
*/
std::string GetMMARegisterType(const ptx::DataType &dtype);
} // namespace codegen
} // namespace tvm::tl
......
#include "codegen_cpp.h"
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) {
ffi::Module BuildCPPHost(IRModule mod, Target target) {
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = true;
......@@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
});
}
} // namespace codegen
} // namespace tvm
......@@ -26,18 +26,19 @@ ExtractFuncInfo(const IRModule &mod) {
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -70,7 +71,7 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -93,13 +94,13 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_cuda", BuildTileLangCUDA)
.def("target.build.tilelang_cuda_without_compile",
BuildTileLangCUDAWithoutCompile);
});
}
} // namespace codegen
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment