Unverified Commit a7c9a8b9 authored by Siyuan Feng, committed by GitHub

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.
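
   For example, a kernel that early-exits out-of-range threads (a minimal sketch; the kernel name, buffer shape, and loop below are illustrative, not code from this PR):
   ```python
   from tvm.script import tir as T

   @T.prim_func
   def kernel(A: T.Buffer((100,), "float32")):
       for tx in T.thread_binding(128, thread="threadIdx.x"):
           if tx >= 100:
               T.thread_return()  # emitted as a plain `return;` in the generated CUDA
           A[tx] = A[tx] + 1.0
   ```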

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters using `typing`:
   ```python
   dyn_type = T.int32 or T.float
   @T.prim_func
   def main(
       a: dyn_type,
   ):
       ...
   ```
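
   A minimal usage sketch showing how one definition can be instantiated per dtype (the `make_main` factory, parameter names, and body are hypothetical, not code from this PR):
   ```python
   def make_main(dyn_type):
       @T.prim_func
       def main(a: dyn_type, A: T.Buffer((1,), "float32")):
           A[0] = T.Cast("float32", a)  # use the scalar argument in the body
       return main

   main_i32 = make_main(T.int32)    # a: T.int32
   main_f32 = make_main(T.float32)  # a: T.float32
   ```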

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
       for bx in...
   ```
parent 8edd6941
......@@ -46,7 +46,7 @@ public:
bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const;
static constexpr const char *_type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor *v);
static void RegisterReflection();
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
private:
......
......@@ -130,7 +130,7 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
for (const IterVar &iter : input_iters) {
IterMark iv_mark;
for (const IterMark &mark : collector.visited_) {
if (mark->source.as<Var>().same_as(iter->var)) {
if (mark->source.as<Var>()->same_as(iter->var)) {
iv_mark = mark;
break;
}
......
......@@ -40,7 +40,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
}
} else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_e4m3_float8() or dtype.is_e5m2_float8()) {
} else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) {
switch (dtype.bits()) {
......@@ -111,6 +111,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Stmt();
}
if (T.layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return Stmt();
}
Array<PrimExpr> indices;
for (auto r : shared_range)
indices.push_back(r->min);
......
......@@ -154,7 +154,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, NullOpt, annotations);
ForKind::kParallel, body, std::nullopt, annotations);
}
return Downcast<For>(body);
}
......@@ -254,12 +254,12 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
FloorMod(local_layout->ForwardThread(local_indices, std::nullopt), 32);
PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr matrix_8x8_thread_map_trans =
makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, std::nullopt);
PrimExpr local_indices_flattened =
local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
......@@ -376,13 +376,13 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (T.layout_map.count(src) && T.layout_map.count(dst)) {
// Only compare fragment layout
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
const FragmentNode *src_layout = T.layout_map[src].as<Fragment>().get();
const FragmentNode *dst_layout = T.layout_map[dst].as<Fragment>().get();
const auto &src_layout = T.layout_map[src].as<Fragment>();
const auto &dst_layout = T.layout_map[dst].as<Fragment>();
if (src_layout && dst_layout) {
ICHECK(src_layout->IsEqual(dst_layout, true))
ICHECK((*src_layout)->IsEqual(dst_layout->get(), true))
<< "Get different layout for " << src << " and " << dst
<< "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << dst_layout->DebugOutput()
<< "\nLHS = " << (*src_layout)->DebugOutput()
<< "\nRHS = " << (*dst_layout)->DebugOutput()
<< "\nYou may need to use a shared memory to transform the layout";
}
}
......
......@@ -223,17 +223,13 @@ bool Gemm::CheckWGMMA() const {
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E4M3())
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E5M2())
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E4M3())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E5M2())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
......@@ -245,17 +241,13 @@ bool Gemm::CheckWGMMA() const {
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E4M3())
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E5M2())
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E4M3())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E5M2())
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
......
......@@ -4,7 +4,7 @@
*
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......
......@@ -4,7 +4,7 @@
*
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......
......@@ -22,7 +22,7 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc = ffi::TypedFunction<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op &Entry::Get() { \
......
......@@ -230,7 +230,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
// Check if coalesced_width is defined
if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto *imm = coalesced_width.as<IntImmNode>()) {
if (const auto *imm = coalesced_width->as<IntImmNode>()) {
int expected = imm->value;
// Verify that vector_size is divisible by expected
if (vector_size % expected != 0) {
......@@ -278,8 +278,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
continue;
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
auto lhs = loop_layout_->ForwardThread(vars, std::nullopt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], std::nullopt);
auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff))
<< "Layout infer conflict for " << buffer << " " << source_buffer
......@@ -304,11 +304,10 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
source_buffer.scope() == "local.fragment") {
if (T.layout_map.count(buffer)) {
const FragmentNode *src_layout =
T.layout_map[buffer].as<Fragment>().get();
T.layout_map[buffer].as<FragmentNode>();
Fragment dst_layout_fragment =
CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
const FragmentNode *dst_layout =
dst_layout_fragment.as<Fragment>().get();
const FragmentNode *dst_layout = dst_layout_fragment.as<FragmentNode>();
if (as_const_int(dst_layout->ReplicateExtent()) &&
as_const_int(src_layout->ReplicateExtent()) &&
(*as_const_int(dst_layout->ReplicateExtent()) >
......@@ -336,7 +335,7 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
} else {
return NullOpt;
return std::nullopt;
}
}
......@@ -362,7 +361,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
std::nullopt)
->CondenseReplicateVar();
}
......
......@@ -201,7 +201,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
ForKind::kUnrolled, reduce_local, NullOpt,
ForKind::kUnrolled, reduce_local, std::nullopt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
stmts.push_back(reduce_local);
......@@ -213,7 +213,7 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto &iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined());
ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent);
......@@ -307,7 +307,7 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds);
return {{dst, dst_layout}};
......
......@@ -7,13 +7,12 @@
#include "runtime.h"
#include "../target/cuda.h"
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/node/node.h>
namespace tvm {
namespace tl {
using namespace runtime;
#if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss;
......@@ -39,37 +38,35 @@ struct TensorMapArgs {
CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill;
static TensorMapArgs Extract(TVMArgs args) {
static TensorMapArgs Extract(PackedArgs args) {
TensorMapArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(args.size() >= 8);
T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
T.globalAddress = args[idx++].cast<void *>();
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(8 + T.tensorRank * 4));
ICHECK(args.size() == static_cast<int>(8 + T.tensorRank * 4));
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
T.globalDim[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
T.globalStride[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.boxDim[i] = static_cast<cuuint64_t>(args[idx++]);
T.boxDim[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
}
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
return T;
}
......@@ -93,20 +90,23 @@ struct TensorMapArgs {
};
// set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
struct TensorMapIm2ColArgs {
CUtensorMap *map;
......@@ -122,42 +122,40 @@ struct TensorMapIm2ColArgs {
CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill;
static TensorMapIm2ColArgs Extract(TVMArgs args) {
static TensorMapIm2ColArgs Extract(PackedArgs args) {
TensorMapIm2ColArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(args.size() >= 8);
T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
T.globalAddress = args[idx++].cast<void *>();
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(6 + T.tensorRank * 5));
ICHECK(args.size() == static_cast<int>(6 + T.tensorRank * 5));
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
T.globalDim[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
T.globalStride[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
}
for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxLowerCorner[i] = static_cast<int>(args[idx++]);
T.pixelBoxLowerCorner[i] = args[idx++].cast<int>();
}
for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxUpperCorner[i] = static_cast<int>(args[idx++]);
T.pixelBoxUpperCorner[i] = args[idx++].cast<int>();
}
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_pixel = args[idx++].cast<cuuint64_t>();
T.smem_box_channel = args[idx++].cast<cuuint64_t>();
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
return T;
}
......@@ -185,21 +183,25 @@ struct TensorMapIm2ColArgs {
}
};
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
T.smem_box_channel, T.smem_box_pixel, T.elementStrides,
T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
#endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl
......
......@@ -22,27 +22,22 @@
*/
#include "codegen_cpp.h"
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "support/str_escape.h"
#include "target/build_common.h"
#include "target/func_registry_generator.h"
#include "target/source/codegen_params.h"
namespace tvm {
namespace codegen {
CodeGenTileLangCPP::CodeGenTileLangCPP() {
module_name_ = name_supply_->FreshName("__tvm_module_ctx");
module_name_ = name_supply_->FreshName("__tvm_ffi_library_ctx");
}
void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
......@@ -59,7 +54,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
}
void CodeGenTileLangCPP::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx
decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
}
......@@ -384,13 +379,13 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op,
const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMValue);
static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMFFIAny);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
size = (num->value * sizeof(runtime::tvm_index_t) + unit - 1) / unit;
} else if (type == "arg_value") {
size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") {
......@@ -399,7 +394,7 @@ void CodeGenTileLangCPP::VisitExpr_(const CallNode *op,
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
auto function_info = GetFunctionInfo(op, false /* has_resource_handle */);
......
......@@ -4,7 +4,7 @@
#include "codegen_cuda.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
......@@ -39,15 +39,75 @@ static std::string GetFP8Type(DataType type) {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
"for FP8";
}
if (type.code() == DataType::kFloat8_e4m3fn) {
if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
type.is_float8_e4m3()) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3fnuz) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2) {
} else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
type.is_float8_e5m2()) {
stream << "fp8_e5" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
}
return stream.str();
}
std::string GetFP6Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "x2";
} else if (lanes == 4) {
vec = "x4";
} else if (lanes == 8) {
vec = "x8";
} else if (lanes == 16) {
vec = "x16";
} else {
LOG(FATAL)
<< "Only support scalar and vector types of width (2, 4) for FP6";
}
stream << "__nv_fp6";
std::string suffix;
if (type.code() == DataType::kFloat6_e2m3fn) {
suffix = "_e2m3";
} else if (type.code() == DataType::kFloat6_e3m2fn) {
suffix = "_e3m2";
} else {
LOG(FATAL) << "Unsupported FP6 type in CUDA codegen";
}
stream << vec << suffix;
return stream.str();
}
std::string GetFP4Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "x2";
} else if (lanes == 4) {
vec = "x4";
} else if (lanes == 8) {
vec = "x8";
} else if (lanes == 16) {
vec = "x16";
} else {
LOG(FATAL)
<< "Only support scalar and vector types of width (2, 4) for FP4";
}
stream << "__nv_fp4";
std::string suffix;
if (type.code() == DataType::kFloat4_e2m1fn) {
suffix = "_e2m1";
} else {
LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
}
stream << vec << suffix;
return stream.str();
}
......@@ -259,6 +319,22 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
enable_fp8_ = true;
os << GetFP8Type(t);
return;
} else if (t.is_float6()) {
enable_fp6_ = true;
if (t.lanes() <= 4) {
os << GetFP6Type(t);
} else {
fail = true;
}
return;
} else if (t.is_float4()) {
enable_fp4_ = true;
if (t.lanes() <= 4) {
os << GetFP4Type(t);
} else {
fail = true;
}
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
......@@ -678,7 +754,7 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) {
if (ret_dtype.is_fixed_length_vector()) {
//
// Emit an unsupported vector call
//
......@@ -799,13 +875,19 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent();
this->stream << name << "(";
// Cache context into a private ss, otherwise the let node may generate
// within the function call arguments.
std::ostringstream ss;
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
ss << ", ";
ss << this->PrintExpr(op->args[i]);
}
this->PrintIndent();
this->stream << name << "(";
this->stream << ss.str();
this->stream << ");\n";
};
if (op->op.same_as(builtin::ptx_cp_async())) {
......@@ -858,22 +940,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::tma_load())) {
this->PrintIndent();
std::ostringstream ss;
ICHECK_GE(op->args.size(), 2);
this->stream << "tl::tma_load(";
ss << "tl::tma_load(";
auto desc = op->args[0];
this->stream << this->PrintExpr(desc) << ", ";
ss << this->PrintExpr(desc) << ", ";
if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
this->stream << "_mbarrier[" << imm->value << "], ";
ss << "_mbarrier[" << imm->value << "], ";
} else {
this->stream << this->PrintExpr(op->args[1]) << ", ";
ss << this->PrintExpr(op->args[1]) << ", ";
}
for (size_t i = 2; i < op->args.size(); i++) {
if (i > 2)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
ss << ", ";
ss << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
ss << ");\n";
this->PrintIndent();
this->stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) {
print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::tma_store())) {
......@@ -1111,8 +1195,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
// To store the 32x8 output back to a 16x16 tile in shared or global memory,
// we invert this map to determine the output location for each 8 element.
const auto *index_map_func =
runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
const auto index_map_func = ffi::Function::GetGlobal(
"tir.index_map.shared_16x16_to_mma_32x8_layout");
IndexMap index_map;
if (!index_map_func) {
......@@ -1289,6 +1373,100 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
<< ")), \"r\"((int)" << guard << ")\n";
stream << ");\n";
} else if (op->op.same_as(builtin::reinterpret())) {
DataType tgt_dtype = op->dtype;
DataType src_dtype = op->args[0]->dtype;
PrimExpr value = op->args[0];
// Handle float4_e2m1fn reinterpret
if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
return CodeGenC::VisitExpr_(op, os);
}
if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() ==
src_dtype.lanes() * src_dtype.bits()) {
return CodeGenC::VisitExpr_(op, os);
}
CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
<< "E2M1 float4 reinterpret expects source and target to have the same "
"number of lanes. "
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
<< "E2M1 float4 reinterpret expects source and target to have the same "
"number of bytes. "
<< "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
int lanes = tgt_dtype.lanes();
int ssa_scope = BeginScope();
if (lanes == 1) {
// The case of lane=1 is same as the normal reinterpret,
// except that we allow the src and dst dtype to have different number of
// bits.
std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
os << "(*(";
this->PrintType(tgt_dtype, os);
os << " *)(&(" << rhs << ")))";
} else if (lanes == 2) {
if (tgt_dtype.is_float4_e2m1fn()) {
// We view the source as an uint16, and then extract bits of two fp4
// numbers, and finally reinterpret the result as fp4x2.
value =
tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
tir::Var temp_var("temp_var", DataType::UInt(16));
value =
tir::Let(temp_var, value,
tir::Cast(DataType::UInt(8),
(temp_var & IntImm(DataType::UInt(16), 0xF)) |
((temp_var >> 4) &
IntImm(DataType::UInt(16), 0xF0))));
} else {
value = tir::Cast(
DataType::UInt(16),
tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
tir::Var temp_var("temp_var", DataType::UInt(16));
value =
tir::Let(temp_var, value,
(temp_var & IntImm(DataType::UInt(16), 0xF)) |
((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
}
os << PrintExpr(
tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
} else if (lanes == 4) {
if (tgt_dtype.is_float4_e2m1fn()) {
// We view the source as an uint32, and then extract bits of four fp4
// numbers, and finally reinterpret the result as fp4x4.
value =
tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
tir::Var temp_var("temp_var", DataType::UInt(32));
value = tir::Let(
temp_var, value,
tir::Cast(
DataType::UInt(16),
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
} else {
value = tir::Cast(DataType::UInt(32),
tir::Call(DataType::UInt(16),
tir::builtin::reinterpret(), {value}));
tir::Var temp_var("temp_var", DataType::UInt(32));
value = tir::Let(
temp_var, value,
(temp_var & IntImm(DataType::UInt(32), 0xF)) |
((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
}
os << PrintExpr(
tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
} else {
LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
<< lanes;
}
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else {
CodeGenC::VisitExpr_(op, os);
}
......
......@@ -80,16 +80,21 @@ private:
std::string vid_global_barrier_state_;
// Global barrier expected node.
std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
// whether enable sparse gemm
bool enable_sparse_gemm_{false};
// whether enable fp6
bool enable_fp6_{false};
// whether enable fp4
bool enable_fp4_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable sparse gemm
bool enable_sparse_gemm_{false};
// whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false};
// whether need math_constants.h
......
......@@ -21,6 +21,7 @@
* \file codegen_webgpu.cc
*/
#include "codegen_webgpu.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h>
......@@ -704,11 +705,11 @@ public:
return runtime::ModulePropertyMask::kBinarySerializable;
}
PackedFunc GetFunction(const String &name,
const ObjectPtr<Object> &sptr_to_self) final {
ffi::Function GetFunction(const String &name,
const ObjectPtr<Object> &sptr_to_self) final {
LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run "
"through tvmjs";
return PackedFunc(nullptr);
return ffi::Function(nullptr);
}
void SaveToBinary(dmlc::Stream *stream) final {
......@@ -773,10 +774,13 @@ runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) {
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_webgpu")
.set_body_typed([](IRModule mod, Target target) {
return BuildTileLangWebGPU(mod, target);
});
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_webgpu",
[](IRModule mod, Target target) {
return BuildTileLangWebGPU(mod, target);
});
});
} // namespace codegen
} // namespace tvm
#include "codegen_cpp.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = true;
......@@ -67,7 +67,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) {
return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cpp").set_body_typed(BuildCPPHost);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost);
});
} // namespace codegen
} // namespace tvm
#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace codegen {
......@@ -18,7 +20,7 @@ ExtractFuncInfo(const IRModule &mod) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
continue;
}
}
......@@ -36,7 +38,6 @@ ExtractFuncInfo(const IRModule &mod) {
}
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -52,13 +53,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto *f = Registry::Get("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).operator std::string();
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).cast<std::string>();
if (ptx[0] != '/')
fmt = "cubin";
} else {
......@@ -68,7 +71,6 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
}
runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
......@@ -84,16 +86,20 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto *f = Registry::Get("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
.set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda_without_compile")
.set_body_typed(BuildTileLangCUDAWithoutCompile);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_cuda", BuildTileLangCUDA)
.def("target.build.tilelang_cuda_without_compile",
BuildTileLangCUDAWithoutCompile);
});
} // namespace codegen
} // namespace tvm
#if defined(__linux__)
#include <sys/stat.h>
#include <tvm/ffi/reflection/registry.h>
#endif
#include <hip/hip_runtime.h>
......@@ -95,10 +96,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code,
std::string());
}
TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
.set_body_typed(BuildTileLangHIP);
TVM_REGISTER_GLOBAL("target.build.tilelang_hip_without_compile")
.set_body_typed(BuildTileLangHIPWithoutCompile);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("target.build.tilelang_hip", BuildTileLangHIP)
.def("target.build.tilelang_hip_without_compile",
BuildTileLangHIPWithoutCompile);
});
} // namespace codegen
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief align dynamic shared memory allocations
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -147,8 +148,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
"tl.AlignDynamicSharedMemoryAllocations", {});
}
TVM_REGISTER_GLOBAL("tl.transform.AlignDynamicSharedMemoryAllocations")
.set_body_typed(AlignDynamicSharedMemoryAllocations);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations",
AlignDynamicSharedMemoryAllocations);
});
} // namespace tl
} // namespace tvm
......@@ -22,8 +22,9 @@
* \brief Split device function from host.
*/
#include "tir/transforms/ir_utils.h"
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -87,8 +88,11 @@ tvm::transform::Pass AnnotateDeviceRegions() {
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
}
TVM_REGISTER_GLOBAL("tl.transform.AnnotateDeviceRegions")
.set_body_typed(AnnotateDeviceRegions);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions",
AnnotateDeviceRegions);
});
} // namespace tl
} // namespace tvm