Unverified commit a7c9a8b9 authored by Siyuan Feng, committed by GitHub
Browse files

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   )
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
...@@ -46,7 +46,7 @@ public: ...@@ -46,7 +46,7 @@ public:
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const; bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout"; static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const; bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v); static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode); TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
private: private:
......
...@@ -130,7 +130,7 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs, ...@@ -130,7 +130,7 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
for (const IterVar &iter : input_iters) { for (const IterVar &iter : input_iters) {
IterMark iv_mark; IterMark iv_mark;
for (const IterMark &mark : collector.visited_) { for (const IterMark &mark : collector.visited_) {
if (mark->source.as<Var>().same_as(iter->var)) { if (mark->source.as<Var>()->same_as(iter->var)) {
iv_mark = mark; iv_mark = mark;
break; break;
} }
......
...@@ -40,7 +40,7 @@ static int to_CUtensorMapDataType(DataType dtype) { ...@@ -40,7 +40,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
} }
} else if (dtype.is_bfloat16()) { } else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_e4m3_float8() or dtype.is_e5m2_float8()) { } else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) { } else if (dtype.is_int()) {
switch (dtype.bits()) { switch (dtype.bits()) {
...@@ -111,6 +111,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -111,6 +111,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt(); return Stmt();
} }
if (T.layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return Stmt();
}
Array<PrimExpr> indices; Array<PrimExpr> indices;
for (auto r : shared_range) for (auto r : shared_range)
indices.push_back(r->min); indices.push_back(r->min);
......
...@@ -154,7 +154,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -154,7 +154,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
annotations.Set("coalesced_width", coalesced_width); annotations.Set("coalesced_width", coalesced_width);
} }
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, NullOpt, annotations); ForKind::kParallel, body, std::nullopt, annotations);
} }
return Downcast<For>(body); return Downcast<For>(body);
} }
...@@ -254,12 +254,12 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -254,12 +254,12 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
IterVar col_var = loop_vars[loop_vars.size() - 1]; IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2]; IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map = PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32); FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread( PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt); {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr matrix_8x8_thread_map_trans = PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread( makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt); {FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr local_indices_flattened = PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back(); local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) && if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
...@@ -376,13 +376,13 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -376,13 +376,13 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (T.layout_map.count(src) && T.layout_map.count(dst)) { if (T.layout_map.count(src) && T.layout_map.count(dst)) {
// Only compare fragment layout // Only compare fragment layout
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const FragmentNode *src_layout = T.layout_map[src].as<Fragment>().get(); const auto &src_layout = T.layout_map[src].as<Fragment>();
const FragmentNode *dst_layout = T.layout_map[dst].as<Fragment>().get(); const auto &dst_layout = T.layout_map[dst].as<Fragment>();
if (src_layout && dst_layout) { if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true)) ICHECK((*src_layout)->IsEqual(dst_layout->get(), true))
<< "Get different layout for " << src << " and " << dst << "Get different layout for " << src << " and " << dst
<< "\nLHS = " << src_layout->DebugOutput() << "\nLHS = " << (*src_layout)->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput() << "\nRHS = " << (*dst_layout)->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout"; << "\nYou may need to use a shared memory to transform the layout";
} }
} }
......
...@@ -223,17 +223,13 @@ bool Gemm::CheckWGMMA() const { ...@@ -223,17 +223,13 @@ bool Gemm::CheckWGMMA() const {
if (C->dtype == DataType::Float(16)) { if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0; return K % 16 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() && else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() && else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() && else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() && else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else else
return false; return false;
...@@ -245,17 +241,13 @@ bool Gemm::CheckWGMMA() const { ...@@ -245,17 +241,13 @@ bool Gemm::CheckWGMMA() const {
return K % 16 == 0; return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32)) else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0; return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() && else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() && else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() && else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() && else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0; return (!trans_A) && trans_B && K % 32 == 0;
else else
return false; return false;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* *
*/ */
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* *
*/ */
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
......
...@@ -22,7 +22,7 @@ using namespace tir; ...@@ -22,7 +22,7 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>; using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>; using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>; using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>; using OpBuilderFunc = ffi::TypedFunction<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \ #define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \ const Op &Entry::Get() { \
......
...@@ -230,7 +230,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -230,7 +230,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Check if coalesced_width is defined // Check if coalesced_width is defined
if (auto coalesced_width = if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) { root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto *imm = coalesced_width.as<IntImmNode>()) { if (const auto *imm = coalesced_width->as<IntImmNode>()) {
int expected = imm->value; int expected = imm->value;
// Verify that vector_size is divisible by expected // Verify that vector_size is divisible by expected
if (vector_size % expected != 0) { if (vector_size % expected != 0) {
...@@ -278,8 +278,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -278,8 +278,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
continue; continue;
auto vars = auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt); auto lhs = loop_layout_->ForwardThread(vars, std::nullopt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt); auto rhs = fragment->ForwardThread(indice_map_[buffer], std::nullopt);
auto diff = analyzer_.Simplify(lhs - rhs); auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff)) ICHECK(is_zero(diff))
<< "Layout infer conflict for " << buffer << " " << source_buffer << "Layout infer conflict for " << buffer << " " << source_buffer
...@@ -304,11 +304,10 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -304,11 +304,10 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
source_buffer.scope() == "local.fragment") { source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) { if (T.layout_map.count(buffer)) {
const FragmentNode *src_layout = const FragmentNode *src_layout =
T.layout_map[buffer].as<Fragment>().get(); T.layout_map[buffer].as<FragmentNode>();
Fragment dst_layout_fragment = Fragment dst_layout_fragment =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds); CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout = const FragmentNode *dst_layout = dst_layout_fragment.as<FragmentNode>();
dst_layout_fragment.as<Fragment>().get();
if (as_const_int(dst_layout->ReplicateExtent()) && if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) && as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) > (*as_const_int(dst_layout->ReplicateExtent()) >
...@@ -336,7 +335,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const { ...@@ -336,7 +335,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
if (predicate_.defined()) { if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}}); return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
} else { } else {
return NullOpt; return std::nullopt;
} }
} }
...@@ -362,7 +361,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) { ...@@ -362,7 +361,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
PrimExpr thd_b = loop_layout_->ForwardThread( PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd), ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt) return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
std::nullopt)
->CondenseReplicateVar(); ->CondenseReplicateVar();
} }
......
...@@ -201,7 +201,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -201,7 +201,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) { for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local = reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
ForKind::kUnrolled, reduce_local, NullOpt, ForKind::kUnrolled, reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}}); {{tir::attr::pragma_unroll_explicit, Bool(false)}});
} }
stmts.push_back(reduce_local); stmts.push_back(reduce_local);
...@@ -213,7 +213,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -213,7 +213,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer); arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto &iter_split : iter_sum->args) { for (const auto &iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>(); auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined()); ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
if (mark.value().same_as(src_vars[this->dim]->var)) { if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale); auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent); auto extent = as_const_int(iter_split->extent);
...@@ -307,7 +307,7 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -307,7 +307,7 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
auto thd = src_layout->ForwardThread( auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout = Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt) Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar() ->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds); ->BindThreadRange(T.thread_bounds);
return {{dst, dst_layout}}; return {{dst, dst_layout}};
......
...@@ -7,13 +7,12 @@ ...@@ -7,13 +7,12 @@
#include "runtime.h" #include "runtime.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/node/node.h>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace runtime;
#if (CUDA_MAJOR_VERSION >= 12) #if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) { template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss; std::stringstream ss;
...@@ -39,37 +38,35 @@ struct TensorMapArgs { ...@@ -39,37 +38,35 @@ struct TensorMapArgs {
CUtensorMapL2promotion l2Promotion; CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill; CUtensorMapFloatOOBfill oobFill;
static TensorMapArgs Extract(TVMArgs args) { static TensorMapArgs Extract(PackedArgs args) {
TensorMapArgs T; TensorMapArgs T;
int idx = 0; int idx = 0;
ICHECK(args.num_args >= 8); ICHECK(args.size() >= 8);
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++])); T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
T.type = T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++])); T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++])); T.globalAddress = args[idx++].cast<void *>();
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5); ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(8 + T.tensorRank * 4)); ICHECK(args.size() == static_cast<int>(8 + T.tensorRank * 4));
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]); T.globalDim[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]); T.globalStride[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.boxDim[i] = static_cast<cuuint64_t>(args[idx++]); T.boxDim[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]); T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
} }
T.interleave = T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
T.swizzle = T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
T.oobFill = T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
return T; return T;
} }
...@@ -93,20 +90,23 @@ struct TensorMapArgs { ...@@ -93,20 +90,23 @@ struct TensorMapArgs {
}; };
// set device api // set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled) TVM_FFI_STATIC_INIT_BLOCK({
.set_body([](TVMArgs args, TVMRetValue *ret) { namespace refl = tvm::ffi::reflection;
TensorMapArgs T = TensorMapArgs::Extract(args); refl::GlobalDef().def_packed(
CUresult result = cuTensorMapEncodeTiled( "tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) {
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, TensorMapArgs T = TensorMapArgs::Extract(args);
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, CUresult result = cuTensorMapEncodeTiled(
T.swizzle, T.l2Promotion, T.oobFill); T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
if (result != CUDA_SUCCESS) { T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
LOG_FATAL << "Failed to initialize the TMA descriptor " << result T.swizzle, T.l2Promotion, T.oobFill);
<< std::endl if (result != CUDA_SUCCESS) {
<< T.ToDebugString(); LOG_FATAL << "Failed to initialize the TMA descriptor " << result
} << std::endl
*ret = static_cast<int>(result); << T.ToDebugString();
}); }
*ret = static_cast<int>(result);
});
});
struct TensorMapIm2ColArgs { struct TensorMapIm2ColArgs {
CUtensorMap *map; CUtensorMap *map;
...@@ -122,42 +122,40 @@ struct TensorMapIm2ColArgs { ...@@ -122,42 +122,40 @@ struct TensorMapIm2ColArgs {
CUtensorMapL2promotion l2Promotion; CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill; CUtensorMapFloatOOBfill oobFill;
static TensorMapIm2ColArgs Extract(TVMArgs args) { static TensorMapIm2ColArgs Extract(PackedArgs args) {
TensorMapIm2ColArgs T; TensorMapIm2ColArgs T;
int idx = 0; int idx = 0;
ICHECK(args.num_args >= 8); ICHECK(args.size() >= 8);
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++])); T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
T.type = T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++])); T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++])); T.globalAddress = args[idx++].cast<void *>();
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5); ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(6 + T.tensorRank * 5)); ICHECK(args.size() == static_cast<int>(6 + T.tensorRank * 5));
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]); T.globalDim[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]); T.globalStride[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]); T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
} }
for (size_t i = 0; i < T.tensorRank - 2; i++) { for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxLowerCorner[i] = static_cast<int>(args[idx++]); T.pixelBoxLowerCorner[i] = args[idx++].cast<int>();
} }
for (size_t i = 0; i < T.tensorRank - 2; i++) { for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxUpperCorner[i] = static_cast<int>(args[idx++]); T.pixelBoxUpperCorner[i] = args[idx++].cast<int>();
} }
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]); T.smem_box_pixel = args[idx++].cast<cuuint64_t>();
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]); T.smem_box_channel = args[idx++].cast<cuuint64_t>();
T.interleave = T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
T.swizzle = T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
T.oobFill = T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
return T; return T;
} }
...@@ -185,21 +183,25 @@ struct TensorMapIm2ColArgs { ...@@ -185,21 +183,25 @@ struct TensorMapIm2ColArgs {
} }
}; };
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col) TVM_FFI_STATIC_INIT_BLOCK({
.set_body([](TVMArgs args, TVMRetValue *ret) { namespace refl = tvm::ffi::reflection;
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args); refl::GlobalDef().def_packed(
CUresult result = cuTensorMapEncodeIm2col( "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, CUresult result = cuTensorMapEncodeIm2col(
T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave, T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.swizzle, T.l2Promotion, T.oobFill); T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
if (result != CUDA_SUCCESS) { T.smem_box_channel, T.smem_box_pixel, T.elementStrides,
LOG_FATAL << "Failed to initialize the TMA descriptor " << result T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
<< std::endl if (result != CUDA_SUCCESS) {
<< T.ToDebugString(); LOG_FATAL << "Failed to initialize the TMA descriptor " << result
} << std::endl
*ret = static_cast<int>(result); << T.ToDebugString();
}); }
*ret = static_cast<int>(result);
});
});
#endif // (CUDA_MAJOR_VERSION >= 12) #endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl } // namespace tl
......
...@@ -22,27 +22,22 @@ ...@@ -22,27 +22,22 @@
*/ */
#include "codegen_cpp.h" #include "codegen_cpp.h"
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <algorithm>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector>
#include "support/str_escape.h" #include "support/str_escape.h"
#include "target/build_common.h" #include "target/build_common.h"
#include "target/func_registry_generator.h"
#include "target/source/codegen_params.h" #include "target/source/codegen_params.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
CodeGenTileLangCPP::CodeGenTileLangCPP() { CodeGenTileLangCPP::CodeGenTileLangCPP() {
module_name_ = name_supply_->FreshName("__tvm_module_ctx"); module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx");
} }
void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts, void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
...@@ -59,7 +54,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts, ...@@ -59,7 +54,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
} }
void CodeGenTileLangCPP::InitGlobalContext() { void CodeGenTileLangCPP::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
<< " = NULL;\n"; << " = NULL;\n";
} }
...@@ -384,13 +379,13 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op, ...@@ -384,13 +379,13 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op,
const std::string &type = op->args[0].as<StringImmNode>()->value; const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>(); const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr); ICHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant"); static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMValue); size_t unit = sizeof(TVMFFIAny);
size_t size = 0; size_t size = 0;
if (type == "shape") { if (type == "shape") {
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit; size = (num->value * sizeof(runtime::tvm_index_t) + unit - 1) / unit;
} else if (type == "arg_value") { } else if (type == "arg_value") {
size = (num->value * sizeof(TVMValue) + unit - 1) / unit; size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "arg_tcode") { } else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit; size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") { } else if (type == "array") {
...@@ -399,7 +394,7 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op, ...@@ -399,7 +394,7 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op,
LOG(FATAL) << "Unknown stack alloca type " << type; LOG(FATAL) << "Unknown stack alloca type " << type;
} }
this->PrintIndent(); this->PrintIndent();
this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name; os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
auto function_info = GetFunctionInfo(op, false /* has_resource_handle */); auto function_info = GetFunctionInfo(op, false /* has_resource_handle */);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "codegen_cuda.h" #include "codegen_cuda.h"
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/ffi/function.h>
#include <tvm/tir/index_map.h> #include <tvm/tir/index_map.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
...@@ -39,15 +39,75 @@ static std::string GetFP8Type(DataType type) { ...@@ -39,15 +39,75 @@ static std::string GetFP8Type(DataType type) {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) " LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
"for FP8"; "for FP8";
} }
if (type.code() == DataType::kFloat8_e4m3fn) { if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
type.is_float8_e4m3()) {
stream << "fp8_e4" << vec << "_t"; stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3fnuz) { } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
stream << "fp8_e4" << vec << "_t"; type.is_float8_e5m2()) {
} else if (type.code() == DataType::kFloat8_e5m2) {
stream << "fp8_e5" << vec << "_t"; stream << "fp8_e5" << vec << "_t";
} else { } else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
}
return stream.str();
}
/*!
 * \brief Map a TVM float6 DataType to the corresponding CUDA __nv_fp6 type
 *        name (e.g. "__nv_fp6x4_e2m3").
 * \param type The (possibly vectorized) float6 data type.
 * \return The CUDA type string for use in generated device code.
 * \note Aborts via LOG(FATAL) on unsupported lane counts or fp6 variants.
 */
std::string GetFP6Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    // Keep this message in sync with the accepted widths above.
    LOG(FATAL) << "Only support scalar and vector types of width "
                  "(2, 4, 8, 16) for FP6";
  }
  stream << "__nv_fp6";
  std::string suffix;
  if (type.code() == DataType::kFloat6_e2m3fn) {
    suffix = "_e2m3";
  } else if (type.code() == DataType::kFloat6_e3m2fn) {
    suffix = "_e3m2";
  } else {
    // Include the offending type, matching the FP8 error message style.
    LOG(FATAL) << "Unsupported FP6 type in CUDA codegen but got " << type;
  }
  stream << vec << suffix;
  return stream.str();
}
/*!
 * \brief Map a TVM float4 DataType to the corresponding CUDA __nv_fp4 type
 *        name (e.g. "__nv_fp4x2_e2m1").
 * \param type The (possibly vectorized) float4 data type.
 * \return The CUDA type string for use in generated device code.
 * \note Aborts via LOG(FATAL) on unsupported lane counts or fp4 variants.
 */
std::string GetFP4Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    // Keep this message in sync with the accepted widths above.
    LOG(FATAL) << "Only support scalar and vector types of width "
                  "(2, 4, 8, 16) for FP4";
  }
  stream << "__nv_fp4";
  std::string suffix;
  if (type.code() == DataType::kFloat4_e2m1fn) {
    // e2m1 is the only fp4 layout CUDA exposes (__nv_fp4*_e2m1).
    suffix = "_e2m1";
  } else {
    // Include the offending type, matching the FP8 error message style.
    LOG(FATAL) << "Unsupported FP4 type in CUDA codegen but got " << type;
  }
  stream << vec << suffix;
  return stream.str();
}
...@@ -259,6 +319,22 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) ...@@ -259,6 +319,22 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
enable_fp8_ = true; enable_fp8_ = true;
os << GetFP8Type(t); os << GetFP8Type(t);
return; return;
} else if (t.is_float6()) {
enable_fp6_ = true;
if (t.lanes() <= 4) {
os << GetFP6Type(t);
} else {
fail = true;
}
return;
} else if (t.is_float4()) {
enable_fp4_ = true;
if (t.lanes() <= 4) {
os << GetFP4Type(t);
} else {
fail = true;
}
return;
} else if (t == DataType::Bool()) { } else if (t == DataType::Bool()) {
os << "bool"; os << "bool";
return; return;
...@@ -678,7 +754,7 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, ...@@ -678,7 +754,7 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
bool skip_first_arg, bool skip_first_arg,
std::ostream &os) { // NOLINT(*) std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type); DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) { if (ret_dtype.is_fixed_length_vector()) {
// //
// Emit an unsupported vector call // Emit an unsupported vector call
// //
...@@ -799,13 +875,19 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, ...@@ -799,13 +875,19 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) { auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent(); // Cache context into a private ss, otherwise the let node may generate
this->stream << name << "("; // within the function call arguments.
std::ostringstream ss;
for (size_t i = offset; i < op->args.size(); i++) { for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) if (i > offset)
this->stream << ", "; ss << ", ";
this->stream << this->PrintExpr(op->args[i]); ss << this->PrintExpr(op->args[i]);
} }
this->PrintIndent();
this->stream << name << "(";
this->stream << ss.str();
this->stream << ");\n"; this->stream << ");\n";
}; };
if (op->op.same_as(builtin::ptx_cp_async())) { if (op->op.same_as(builtin::ptx_cp_async())) {
...@@ -858,22 +940,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -858,22 +940,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_thread_partial())) { } else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial"); print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::tma_load())) { } else if (op->op.same_as(tl::tma_load())) {
this->PrintIndent(); std::ostringstream ss;
ICHECK_GE(op->args.size(), 2); ICHECK_GE(op->args.size(), 2);
this->stream << "tl::tma_load("; ss << "tl::tma_load(";
auto desc = op->args[0]; auto desc = op->args[0];
this->stream << this->PrintExpr(desc) << ", "; ss << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) { if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
this->stream << "_mbarrier[" << imm->value << "], "; ss << "_mbarrier[" << imm->value << "], ";
} else { } else {
this->stream << this->PrintExpr(op->args[1]) << ", "; ss << this->PrintExpr(op->args[1]) << ", ";
} }
for (size_t i = 2; i < op->args.size(); i++) { for (size_t i = 2; i < op->args.size(); i++) {
if (i > 2) if (i > 2)
this->stream << ", "; ss << ", ";
this->stream << this->PrintExpr(op->args[i]); ss << this->PrintExpr(op->args[i]);
} }
this->stream << ");\n"; ss << ");\n";
this->PrintIndent();
this->stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) { } else if (op->op.same_as(tl::tma_load_im2col())) {
print_extern_call_stmt("tl::tma_load_im2col"); print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::tma_store())) { } else if (op->op.same_as(tl::tma_store())) {
...@@ -1111,8 +1195,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1111,8 +1195,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
// To store the 32x8 output back to a 16x16 tile in shared or global memory, // To store the 32x8 output back to a 16x16 tile in shared or global memory,
// we invert this map to determine the output location for each 8 element. // we invert this map to determine the output location for each 8 element.
const auto *index_map_func = const auto index_map_func = ffi::Function::GetGlobal(
runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout"); "tir.index_map.shared_16x16_to_mma_32x8_layout");
IndexMap index_map; IndexMap index_map;
if (!index_map_func) { if (!index_map_func) {
...@@ -1289,6 +1373,100 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1289,6 +1373,100 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
<< ")), \"r\"((int)" << guard << ")\n"; << ")), \"r\"((int)" << guard << ")\n";
stream << ");\n"; stream << ");\n";
} else if (op->op.same_as(builtin::reinterpret())) {
DataType tgt_dtype = op->dtype;
DataType src_dtype = op->args[0]->dtype;
PrimExpr value = op->args[0];
// Handle float4_e2m1fn reinterpret
if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
return CodeGenC::VisitExpr_(op, os);
}
if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() ==
src_dtype.lanes() * src_dtype.bits()) {
return CodeGenC::VisitExpr_(op, os);
}
CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
<< "E2M1 float4 reinterpret expects source and target to have the same "
"number of lanes. "
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
<< "E2M1 float4 reinterpret expects source and target to have the same "
"number of bytes. "
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
int lanes = tgt_dtype.lanes();
int ssa_scope = BeginScope();
if (lanes == 1) {
// The case of lane=1 is same as the normal reinterpret,
// except that we allow the src and dst dtype to have different number of
// bits.
std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
os << "(*(";
this->PrintType(tgt_dtype, os);
os << " *)(&(" << rhs << ")))";
} else if (lanes == 2) {
if (tgt_dtype.is_float4_e2m1fn()) {
// We view the source as an uint16, and then extract bits of two fp4
// numbers, and finally reinterpret the result as fp4x2.
value =
tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
tir::Var temp_var("temp_var", DataType::UInt(16));
value =
tir::Let(temp_var, value,
tir::Cast(DataType::UInt(8),
(temp_var & IntImm(DataType::UInt(16), 0xF)) |
((temp_var >> 4) &
IntImm(DataType::UInt(16), 0xF0))));
} else {
value = tir::Cast(
DataType::UInt(16),
tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
tir::Var temp_var("temp_var", DataType::UInt(16));
value =
tir::Let(temp_var, value,
(temp_var & IntImm(DataType::UInt(16), 0xF)) |
((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
}
os << PrintExpr(
tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
} else if (lanes == 4) {
if (tgt_dtype.is_float4_e2m1fn()) {
// We view the source as an uint32, and then extract bits of four fp4
// numbers, and finally reinterpret the result as fp4x4.
value =
tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
tir::Var temp_var("temp_var", DataType::UInt(32));
value = tir::Let(
temp_var, value,
tir::Cast(
DataType::UInt(16),
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
} else {
value = tir::Cast(DataType::UInt(32),
tir::Call(DataType::UInt(16),
tir::builtin::reinterpret(), {value}));
tir::Var temp_var("temp_var", DataType::UInt(32));
value = tir::Let(
temp_var, value,
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
}
os << PrintExpr(
tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
} else {
LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
<< lanes;
}
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else { } else {
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
......
...@@ -80,16 +80,21 @@ private: ...@@ -80,16 +80,21 @@ private:
std::string vid_global_barrier_state_; std::string vid_global_barrier_state_;
// Global barrier expected node. // Global barrier expected node.
std::string vid_global_barrier_expect_; std::string vid_global_barrier_expect_;
// whether enable fp16 // whether enable fp16
bool enable_fp16_{false}; bool enable_fp16_{false};
// whether enable bf16 // whether enable bf16
bool enable_bf16_{false}; bool enable_bf16_{false};
// whether enable fp8 // whether enable fp8
bool enable_fp8_{false}; bool enable_fp8_{false};
// whether enable sparse gemm // whether enable fp6
bool enable_sparse_gemm_{false}; bool enable_fp6_{false};
// whether enable fp4
bool enable_fp4_{false};
// whether enable int8 // whether enable int8
bool enable_int8_{false}; bool enable_int8_{false};
// whether enable sparse gemm
bool enable_sparse_gemm_{false};
// whether enable warp shuffle intrinsics // whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false}; bool enable_warp_shuffle_{false};
// whether need math_constants.h // whether need math_constants.h
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file codegen_webgpu.cc * \file codegen_webgpu.cc
*/ */
#include "codegen_webgpu.h" #include "codegen_webgpu.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
...@@ -704,11 +705,11 @@ public: ...@@ -704,11 +705,11 @@ public:
return runtime::ModulePropertyMask::kBinarySerializable; return runtime::ModulePropertyMask::kBinarySerializable;
} }
PackedFunc GetFunction(const String &name, ffi::Function GetFunction(const String &name,
const ObjectPtr<Object> &sptr_to_self) final { const ObjectPtr<Object> &sptr_to_self) final {
LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run " LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run "
"through tvmjs"; "through tvmjs";
return PackedFunc(nullptr); return ffi::Function(nullptr);
} }
void SaveToBinary(dmlc::Stream *stream) final { void SaveToBinary(dmlc::Stream *stream) final {
...@@ -773,10 +774,13 @@ runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) { ...@@ -773,10 +774,13 @@ runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) {
return runtime::Module(n); return runtime::Module(n);
} }
TVM_REGISTER_GLOBAL("target.build.tilelang_webgpu") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed([](IRModule mod, Target target) { namespace refl = tvm::ffi::reflection;
return BuildTileLangWebGPU(mod, target); refl::GlobalDef().def("target.build.tilelang_webgpu",
}); [](IRModule mod, Target target) {
return BuildTileLangWebGPU(mod, target);
});
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#include "codegen_cpp.h" #include "codegen_cpp.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) { runtime::Module BuildCPPHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
bool emit_asserts = false; bool emit_asserts = false;
bool emit_fwd_func_decl = true; bool emit_fwd_func_decl = true;
...@@ -67,7 +67,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) { ...@@ -67,7 +67,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
} }
TVM_REGISTER_GLOBAL("target.build.tilelang_cpp").set_body_typed(BuildCPPHost); TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#include "codegen_cuda.h" #include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h" #include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -18,7 +20,7 @@ ExtractFuncInfo(const IRModule &mod) { ...@@ -18,7 +20,7 @@ ExtractFuncInfo(const IRModule &mod) {
if (f->params[i]->dtype.is_handle()) { if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>(); auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") { if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1)); info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
continue; continue;
} }
} }
...@@ -36,7 +38,6 @@ ExtractFuncInfo(const IRModule &mod) { ...@@ -36,7 +38,6 @@ ExtractFuncInfo(const IRModule &mod) {
} }
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenTileLangCUDA cg; CodeGenTileLangCUDA cg;
cg.Init(output_ssa); cg.Init(output_ssa);
...@@ -52,13 +53,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { ...@@ -52,13 +53,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) { if (const auto f =
code = (*f)(code, target).operator std::string(); ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).cast<std::string>();
} }
std::string fmt = "ptx"; std::string fmt = "ptx";
std::string ptx; std::string ptx;
if (const auto *f = Registry::Get("tilelang_callback_cuda_compile")) { if (const auto f =
ptx = (*f)(code, target).operator std::string(); ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).cast<std::string>();
if (ptx[0] != '/') if (ptx[0] != '/')
fmt = "cubin"; fmt = "cubin";
} else { } else {
...@@ -68,7 +71,6 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { ...@@ -68,7 +71,6 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
} }
runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenTileLangCUDA cg; CodeGenTileLangCUDA cg;
cg.Init(output_ssa); cg.Init(output_ssa);
...@@ -84,16 +86,20 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { ...@@ -84,16 +86,20 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) { if (const auto f =
code = (*f)(code, target).operator std::string(); ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).cast<std::string>();
} }
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code); return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
} }
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(BuildTileLangCUDA); namespace refl = tvm::ffi::reflection;
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda_without_compile") refl::GlobalDef()
.set_body_typed(BuildTileLangCUDAWithoutCompile); .def("target.build.tilelang_cuda", BuildTileLangCUDA)
.def("target.build.tilelang_cuda_without_compile",
BuildTileLangCUDAWithoutCompile);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#if defined(__linux__) #if defined(__linux__)
#include <sys/stat.h> #include <sys/stat.h>
#include <tvm/ffi/reflection/registry.h>
#endif #endif
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
...@@ -95,10 +96,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { ...@@ -95,10 +96,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code, return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code,
std::string()); std::string());
} }
TVM_REGISTER_GLOBAL("target.build.tilelang_hip") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(BuildTileLangHIP); namespace refl = tvm::ffi::reflection;
TVM_REGISTER_GLOBAL("target.build.tilelang_hip_without_compile") refl::GlobalDef()
.set_body_typed(BuildTileLangHIPWithoutCompile); .def("target.build.tilelang_hip", BuildTileLangHIP)
.def("target.build.tilelang_hip_without_compile",
BuildTileLangHIPWithoutCompile);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \brief align dynamic shared memory allocations * \brief align dynamic shared memory allocations
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
...@@ -147,8 +148,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { ...@@ -147,8 +148,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations", {}); "tl.AlignDynamicSharedMemoryAllocations", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.AlignDynamicSharedMemoryAllocations") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(AlignDynamicSharedMemoryAllocations); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
AlignDynamicSharedMemoryAllocations);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -22,8 +22,9 @@ ...@@ -22,8 +22,9 @@
* \brief Split device function from host. * \brief Split device function from host.
*/ */
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h> #include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -87,8 +88,11 @@ tvm::transform::Pass AnnotateDeviceRegions() { ...@@ -87,8 +88,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {}); return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.AnnotateDeviceRegions") TVM_FFI_STATIC_INIT_BLOCK({
.set_body_typed(AnnotateDeviceRegions); namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment