Unverified commit 7248a810 authored by Gabriel Wu, committed by GitHub

feat(cutedsl): add CuTeDSL backend (#1421)



* feat: CuTeDSL backend

* fix: clang-tidy

* fix: clang-format

* fix: ci

* fix: revert example gemm fp8

* fix: remove duplicate code

* fix: switch-case

* fix: fp16 silence

* fix: TVM IR print

* fix: useless tir

* fix: clang-format

* fix: remove tilelang/contrib/cutedsl/.gitignore

* fix: use hexfloat

* fix: gsym guard

* fix: unknown storage sync type

* fix: string literal

* fix: add args guard

* fix: name hint dedup

* fix: better find_kernel_by_pattern

* fix: set libpath for from_database path

* fix: guard buffer.strides

* fix: from guard

* fix: eviction guard

* fix: use thread local tma descs

* fix: ruff

* fix: drop tma_init_cpp

* fix: exc_info

* fix: negative unmatch early return

* fix: rename postproc func and add test

* fix: handle fast math according to pass config

* fix: dyn_sym parse

* fix: wrap_forward

* fix: use tvm_ffi.libinfo instead of cli

* fix: keep signature

* fix: C++ string safety

* fix: mark tma_store_add as unsupported

* fix: tvm version

* resolve ldsm and cpasync issues.

* fix: minor fixes

* fix: parse signature using ast

* fix: guard global_addr

* fix: create tempfile only when necessary

* fix: use logger.execption for exceptions

* fix: guard lib_path and host_func

* fix: remove tma_cpp_init and add timeout for cpp compile

* add timeout for mbarrier_wait.

* fix: _load_kernel_from_disk signature

* resolve codegen issues.

* fix: logger.exception

* add comment for div_by=1

* merge

* fix: reserve cutlass,cute,tl

* fix: guard tma_store

* fix: allow int64 offset in make_tensor_at_offset

* fix: guard barrier

* fix: add comments for div_by=16

* fix: div_by=1 issue

* delete div_by when offset is 0

* use tl.make_tensor when offset is 0

* fix: explicitly check cutedsl target

* fix: use param.torch_dtype()

---------
Co-authored-by: yuxic <yuxic@nvidia.com>
Co-authored-by: Yong <yong@local>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent a6f59f31
@@ -370,8 +370,27 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
--ignore=./python/jit/test_tilelang_jit_cutedsl.py \
./python
# CuTeDSL JIT tests require GEMM v1 (must be set before importing tilelang).
# Run them in a dedicated step to avoid changing the default GEMM selection
# (and to keep the rest of the CUDA tests on GEMM v2).
- name: Run CuTeDSL JIT tests (GEMM v1) with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: cutedsl-tests
if: contains(matrix.runner.toolkit, 'CUDA')
env:
TILELANG_USE_GEMM_V1: "1"
run: |
cd testing
PYTEST=(
uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
# Avoid xdist contention on a single GPU by running this file in one worker.
"${PYTEST[@]}" --maxfail=3 --numprocesses=1 \
./python/jit/test_tilelang_jit_cutedsl.py
# AMD ROCm tests
- name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: rocm-tests
...
@@ -215,7 +215,11 @@ elseif(USE_CUDA)
src/runtime/runtime.cc
src/target/ptx.cc
src/target/codegen_cuda.cc
src/target/codegen_py.cc
src/target/codegen_utils.cc
src/target/codegen_cutedsl.cc
src/target/rt_mod_cuda.cc
src/target/rt_mod_cutedsl.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS})
...
@@ -14,7 +14,13 @@ cd examples
python -m pytest -n 4 . --verbose --color=yes --durations=0 --showlocals --cache-clear
cd ..
# Run pytest in parallel (4 workers) for all tests in the testing/python directory.
# IMPORTANT: CuTeDSL backend currently requires GEMM v1 (TILELANG_USE_GEMM_V1=1).
# Do NOT export it globally here, or you'll silently change the default GEMM selection
# for unrelated tests. Run the CuTeDSL JIT tests in a separate pytest invocation.
cd testing/python
python -m pytest -n 4 . --ignore=jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear
# CuTeDSL JIT tests (isolate env + avoid xdist contention on a single GPU)
TILELANG_USE_GEMM_V1=1 python -m pytest -n 1 jit/test_tilelang_jit_cutedsl.py --verbose --color=yes --durations=0 --showlocals --cache-clear
cd ..
@@ -7,3 +7,5 @@
# CUDA specific requirements
flash-attn==2.5.8
cuda-python==12.9.4
# CuTeDSL (CUTLASS Python DSL with CuTe support)
nvidia-cutlass-dsl>=4.3.1
/*!
* \file target/codegen_cutedsl.cc
*/
#include "codegen_cutedsl.h"
#include "codegen_utils.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <cmath>
#include <string>
#include <utility>
#include <vector>
#include "../op/builtin.h"
#include "arith/pattern_match.h"
namespace tvm {
namespace codegen {
namespace {
// The maximum loop extent for which cutlass.range_constexpr is used.
// Larger extents would trigger a DSLOptimizationWarning such as:
// This static loop has 128 iterations, which may be very slow to compile,
// consider using `cutlass.range(..., unroll_full=True)` instead.
const int64_t LOOP_UNROLL_THRESHOLD = 64;
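// Illustrative examples of the loops emitted by VisitStmt_(const ForNode*)
// for unrolled loops (variable names are placeholders):
//   extent 32  (<= threshold): for i in cutlass.range_constexpr(32):
//   extent 128 (>  threshold): for i in cutlass.range(128, unroll_full=True):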
void ReplaceAll(std::string &str, const std::string &from,
const std::string &to) {
ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty";
auto pos = str.find(from);
while (pos != std::string::npos) {
str.replace(pos, from.size(), to);
pos = str.find(from, pos + to.size());
}
}
} // namespace
CodeGenTileLangCuTeDSL::CodeGenTileLangCuTeDSL() {
// Read fastmath configuration from current PassContext
auto pass_ctx = tvm::transform::PassContext::Current();
// Read tl.enable_fast_math config, default to false
enable_fastmath_ =
pass_ctx->GetConfig<Bool>(tl::kEnableFastMath, Bool(false)).value();
}
std::string CodeGenTileLangCuTeDSL::CanonicalizeFastmathFunctionName_(
const std::string &func_name) const {
static const std::unordered_map<std::string, std::string> kFastMathMap = {
{"divf", "tl.divf"}, {"exp", "tl.exp"}, {"expf", "tl.exp"},
{"exp2", "tl.exp2"}, {"exp2f", "tl.exp2"}, {"log", "tl.log"},
{"logf", "tl.log"}, {"log2", "tl.log2"}, {"log2f", "tl.log2"},
{"log10", "tl.log10"}, {"tan", "tl.tan"}, {"cos", "tl.cos"},
{"sin", "tl.sin"}, {"sqrt", "tl.sqrt"}, {"sqrtf", "tl.sqrt"},
};
auto it = kFastMathMap.find(func_name);
if (it != kFastMathMap.end()) {
return it->second;
}
return "";
}
void CodeGenTileLangCuTeDSL::PrintFuncDecorator_(
std::ostream &os) { // NOLINT(*)
os << "@cute.kernel\n";
}
void CodeGenTileLangCuTeDSL::PreFunctionBody_(const PrimFunc &f) {
PrintIndent();
stream << "threadIdx = tl.ThreadIdx()" << "\n";
PrintIndent();
stream << "blockIdx = tl.BlockIdx()" << "\n";
}
namespace {
std::string DTypeToString(DataType t) {
ICHECK(t.is_scalar()) << "unsupported type " << t;
if (t.is_void()) {
return "void";
}
if (t == tl::cuTensorMapType()) {
return "CUtensorMap";
}
int bits = t.bits();
std::string elem_type;
if (t.is_float()) {
if (bits == 16 || bits == 32 || bits == 64) {
elem_type = "Float" + std::to_string(bits);
}
} else if (t.is_bfloat16()) {
elem_type = "BFloat16";
} else if (t.is_float8()) {
if (t.is_float8_e3m4()) {
// unsupported
} else if (t.is_float8_e4m3()) {
elem_type =
"Float8E4M3FN"; // Only Float8E4M3FN is supported at the moment
} else if (t.is_float8_e4m3b11fnuz()) {
// unsupported
} else if (t.is_float8_e4m3fn()) {
elem_type = "Float8E4M3FN";
} else if (t.is_float8_e4m3fnuz()) {
// unsupported
} else if (t.is_float8_e5m2()) {
elem_type = "Float8E5M2";
} else if (t.is_float8_e5m2fnuz()) {
// unsupported
} else if (t.is_float8_e8m0fnu()) {
elem_type = "Float8E8M0FNU";
}
} else if (t.is_float6()) {
if (t.is_float6_e3m2fn()) {
elem_type = "Float6E3M2FN";
} else if (t.is_float6_e2m3fn()) {
elem_type = "Float6E2M3FN";
}
} else if (t.is_float4()) {
if (t.is_float4_e2m1fn()) {
elem_type = "Float4E2M1FN";
}
} else if (t.is_bool()) {
elem_type = "Boolean";
} else if (t.is_uint()) {
if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 128) {
elem_type = "Uint" + std::to_string(bits);
}
} else if (t.is_int()) {
if (bits == 4 || bits == 8 || bits == 16 || bits == 32 || bits == 64 ||
bits == 128) {
elem_type = "Int" + std::to_string(bits);
}
}
if (elem_type.empty()) {
LOG(FATAL) << "Cannot convert type " << t << " to CuTeDSL type!";
}
return "cutlass." + elem_type;
}
} // namespace
void CodeGenTileLangCuTeDSL::PrintType(DataType t,
std::ostream &os) { // NOLINT(*)
CHECK(t.is_scalar()) << "Should not print a non-scalar type in CuTeDSL: "
<< t;
os << DTypeToString(t);
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
os << "tl.make_filled_tensor((" << PrintExpr_(op->lanes) << ",), "
<< PrintExpr_(op->value) << ").load()";
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
switch (op->dtype.bits()) {
case 64:
case 32:
case 16:
case 8:
case 4: {
std::ostringstream temp;
if (std::isinf(op->value)) {
// For CuTeDSL, use Python's float('inf') instead of CUDA macros
PrintType(op->dtype, temp);
temp << "(";
if (op->value < 0) {
temp << "float('-inf')";
} else {
temp << "float('inf')";
}
temp << ")";
} else if (std::isnan(op->value)) {
// For CuTeDSL, use Python's float('nan')
PrintType(op->dtype, temp);
temp << "(float('nan'))";
} else {
// For CuTeDSL, use Python's float.fromhex() with hexfloat for full
// precision
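// e.g. a 32-bit 1.5 is emitted as cutlass.Float32(float.fromhex('0x1.8p+0'))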
PrintType(op->dtype, temp);
temp << "(float.fromhex('" << std::hexfloat << op->value << "'))";
}
MarkConst(temp.str());
os << temp.str();
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const CastNode *op,
std::ostream &os) { // NOLINT(*)
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
if (from_ty.is_scalar())
return CodeGenTileLangPY::VisitExpr_(op, os);
// Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
PrintIndent();
stream << sret << " = tl.make_rmem_tensor((" << target_ty.lanes() << ",), ";
PrintType(target_ty.element_of(), stream);
stream << ")\n";
std::string src = SSAGetID(PrintExpr_(op->value), from_ty);
PrintIndent();
stream << sret << ".store(" << src << ".to(";
PrintType(target_ty.element_of(), stream);
stream << "))\n";
os << sret << ".load()";
return;
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const DivNode *op,
std::ostream &os) { // NOLINT(*)
if (op->dtype.is_int() || op->dtype.is_uint()) {
PrintBinaryExpr_("//", op->dtype, op->a, op->b, os);
} else {
if (enable_fastmath_) {
os << "tl.divf(" << PrintExpr_(op->a) << ", " << PrintExpr_(op->b)
<< ", fastmath=True)";
} else {
PrintBinaryExpr_("tl.divf", op->dtype, op->a, op->b, os);
}
}
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("tl.min", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("tl.max", op->dtype, op->a, op->b, os);
}
/**
* @brief Emit CuTeDSL-specific code for a call expression.
*
* This visitor handles CallNode intrinsics and builtins that require emitting
* CuTeDSL-specific code (inline PTX/ASM sequences, TensorLanguage runtime
* calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based
* stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The
* function writes the generated code to the provided output stream and falls
* back to the Python codegen for unrecognized calls.
*
* The method recognizes and emits code for (non-exhaustive): cp.async and its
* commit/wait variants, tma_load/store and im2col variants, ptx
* ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy
* MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX
* asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret
* paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm
* and related external calls, and other TL runtime calls.
*
* Side effects:
* - Emits to `os` and the internal codegen output stream.
* - May set internal feature flags (e.g., need_cooperative_groups_).
* - May open/close SSA scopes and mutate internal variable mappings.
* - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument
* patterns.
*
* @param op The call node to generate code for; the function inspects op->op
* and op->args to determine the appropriate emission.
* @param os Output stream to receive expression-level output when the caller
* expects an expression result (some paths write directly to the
* member stream instead).
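*
* Example (illustrative): a five-argument builtin::ptx_cp_async call is
* emitted as the TL runtime statement
*   tl.cp_async_gs(size, dst, dst_offset, src, src_offset)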
*/
void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
size_t end = 0) {
// Cache the printed arguments into a private stream first; otherwise a
// LetNode may emit statements in the middle of the call's argument list.
std::ostringstream ss;
for (size_t i = start; i < op->args.size() - end; i++) {
if (i > start)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
PrintIndent();
stream << name << "(";
stream << ss.str();
stream << ")\n";
};
auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
std::ostringstream ss;
if (barrier_id.as<IntImmNode>()) {
// If barrier_id is an integer constant, print it as an offset into the
// mbarrier array, e.g. "(mbarrier+3)".
ss << "(" << mbarrier_name_ << "+" << barrier_id << ")";
} else {
// Otherwise it may be a T.get_mbarrier() call or a BufferLoad node, so
// print the expression itself.
ss << PrintExpr_(barrier_id);
}
return ss.str();
};
if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = PrintExpr_(op->args[0]);
std::string dst_offset = PrintExpr_(op->args[1]);
std::string src = PrintExpr_(op->args[2]);
std::string src_offset = PrintExpr_(op->args[3]);
std::string size = PrintExpr_(op->args[4]);
// Use the size of the argument list to decide whether to emit the
// predicated cp.async variant.
if (op->args.size() == 5) {
PrintIndent();
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset
<< ", " << src << ", " << src_offset << ")\n";
} else {
std::string condition = PrintExpr_(op->args[5]);
PrintIndent();
stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
<< dst_offset << ", " << src << ", " << src_offset << ", "
<< condition << ")\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl.cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
print_extern_call_stmt("tl.cp_async_wait");
} else if (op->op.same_as(builtin::create_barriers())) {
PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
stream << mbarrier_name_
<< " = tl.alloc_smem(cutlass.Uint64, size_in_elems=" << barrier_count
<< ")\n";
} else if (op->op.same_as(tl::get_mbarrier())) {
ICHECK_EQ(op->args.size(), 1);
std::string barrier_id = PrintExpr_(op->args[0]);
os << "(" << mbarrier_name_ << "+" << barrier_id << ")";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (op->args.size() == 1) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ")\n";
} else if (op->args.size() == 3) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto cta_id = PrintExpr_(op->args[1]);
auto pred = PrintExpr_(op->args[2]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ", " << cta_id << ", "
<< pred << ")\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto arrive_count = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_init(" << mbarrier_obj << ", " << arrive_count
<< ")\n";
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
if (op->args.size() == 2) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->args.size() == 4) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
auto cta_id = PrintExpr_(op->args[2]);
auto pred = PrintExpr_(op->args[3]);
stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ", " << cta_id << ", " << pred << ")\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl.mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::ptx_fence_barrier_init())) {
print_extern_call_stmt("tl.fence_barrier_init");
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
print_extern_call_stmt("tl.mbarrier_cp_async_arrive_noinc");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto phase = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_wait(" << mbarrier_obj << ", " << phase << ")\n";
} else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::no_set_max_nreg())) {
// do nothing
} else if (op->op.same_as(tl::tma_load())) {
std::ostringstream ss;
ICHECK_GE(op->args.size(), 2);
auto pol = op->args[op->args.size() - 1].as<IntImmNode>();
ICHECK(pol) << "Eviction policy must be IntImm";
ICHECK_GE(pol->value, 0);
ICHECK_LT(static_cast<size_t>(pol->value), eviction_policy_names_.size());
auto eviction_policy = eviction_policy_names_[pol->value];
// Only the default eviction policy (EVICT_NORMAL) is supported for now.
if (eviction_policy != "EVICT_NORMAL") {
LOG(FATAL) << "Eviction policy " << eviction_policy
<< " is not supported currently";
} else {
ss << "tl.tma_load(";
}
auto desc = op->args[0];
ss << PrintExpr_(desc) << ", ";
ss << print_mbarrier_obj(op->args[1]) << ", ";
ss << PrintExpr_(op->args[2]) << ", (";
for (size_t i = 3; i < op->args.size() - 1; i++) {
if (i > 3)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
ss << "))\n";
PrintIndent();
stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tma_store())) {
std::stringstream ss;
// Check minimum argument count (desc, data, coords..., need_reduce,
// eviction policy).
ICHECK_GE(op->args.size(), 4) << "tma_store requires at least 4 arguments "
"(desc, data, coords..., need_reduce, "
"eviction_policy), got "
<< op->args.size();
// Safely extract need_reduce flag
auto need_reduce_ptr = op->args[op->args.size() - 2].as<IntImmNode>();
ICHECK(need_reduce_ptr)
<< "tma_store need_reduce flag (args[-2]) must be IntImm, got "
<< op->args[op->args.size() - 2]->GetTypeKey();
auto need_reduce = need_reduce_ptr->value;
if (need_reduce) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
}
// Safely extract and validate eviction policy index
auto eviction_idx_ptr = op->args[op->args.size() - 1].as<IntImmNode>();
ICHECK(eviction_idx_ptr)
<< "tma_store eviction policy (args[-1]) must be IntImm, got "
<< op->args[op->args.size() - 1]->GetTypeKey();
ICHECK_GE(eviction_idx_ptr->value, 0)
<< "tma_store eviction policy index must be >= 0, got "
<< eviction_idx_ptr->value;
ICHECK_LT(static_cast<size_t>(eviction_idx_ptr->value),
eviction_policy_names_.size())
<< "tma_store eviction policy index " << eviction_idx_ptr->value
<< " out of bounds (max " << eviction_policy_names_.size() - 1 << ")";
auto eviction_policy = eviction_policy_names_[eviction_idx_ptr->value];
ss << "tl.tma_store(";
auto desc = op->args[0];
ss << PrintExpr_(desc) << ", ";
ss << PrintExpr_(op->args[1]) << ", (";
for (size_t i = 2; i < op->args.size() - 2; i++) {
if (i > 2)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
ss << ")";
if (eviction_policy != "EVICT_NORMAL") {
ss << ", eviction_kind = nvvm.EvictKind." << eviction_policy.substr(6);
}
ss << ")\n";
PrintIndent();
stream << ss.str();
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl.ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::ptx_stmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl.ptx_stmatrix_x" + std::to_string(num);
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::fence_proxy_async())) {
print_extern_call_stmt("tl.fence_proxy_async");
} else if (op->op.same_as(tl::tma_store_arrive())) {
print_extern_call_stmt("tl.tma_store_arrive");
} else if (op->op.same_as(tl::tma_store_wait())) {
PrintIndent();
stream << "tl.tma_store_wait(0)\n";
} else if (op->op.same_as(tl::warpgroup_arrive())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_commit_batch())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_wait())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::set_max_nreg())) {
PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name =
is_inc ? "tl.warpgroup_reg_alloc" : "tl.warpgroup_reg_dealloc";
stream << func_name << "(" << nreg << ")\n";
} else if (op->op.same_as(tl::wait_wgmma())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::pack_b16())) {
os << "tl.pack_half2(" << PrintExpr_(op->args[0]) << ", "
<< PrintExpr_(op->args[1]) << ")";
} else if (op->op.same_as(tl::sync_grid())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::loop_break())) {
PrintIndent();
stream << "break\n";
} else if (op->op.same_as(builtin::ptx_mma())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_mma_sm70())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_wgmma_ss())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::mma_store())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::mma_fill())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_ldg32())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::reinterpret())) {
DataType tgt_dtype = op->dtype;
DataType src_dtype = op->args[0]->dtype;
ICHECK_EQ(tgt_dtype.lanes() * tgt_dtype.bits(),
src_dtype.lanes() * src_dtype.bits())
<< "reinterpret expects source and target to have the same number of "
"bits";
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1)
<< "CodeGenTileLangCuTeDSL only supports flat memory";
PrimExpr index = load->indices[0];
if (const RampNode *node = index.as<RampNode>(); node) {
auto *p_stride = as_const_int(node->stride);
CHECK(p_stride);
ICHECK_EQ(*p_stride, 1) << "reinterpret expects contiguous elements";
index = node->base;
}
auto ptr_str = GetBufferPtr_(load->buffer.get(), index);
os << "tl.make_tensor(tl.recast_ptr(" << ptr_str << ", dtype=";
PrintType(tgt_dtype.element_of(), os);
os << "), (" << tgt_dtype.lanes() << ",)).load()";
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
PrintCallExtern_(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_lane_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_group_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl.shuffle_elect(" << PrintExpr_(op->args[0]) << ")";
} else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::increase_descriptor_offset())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::__exp())) {
os << "tl.exp2(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__exp10())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::__log())) {
os << "tl.log(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__log2())) {
os << "tl.log2(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__log10())) {
os << "tl.log10(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__tan())) {
os << "tl.tan(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__cos())) {
os << "tl.cos(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__sin())) {
os << "tl.sin(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::ieee_add())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_sub())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_mul())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fmaf())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_frcp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fsqrt())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_frsqrt())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fdiv())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_sum())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_max())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_min())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_bitand())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_bitor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1)
<< "CodeGenTileLangCuTeDSL only supports flat memory";
os << GetBufferPtr_(load->buffer.get(), load->indices[0]);
} else {
CodeGenTileLangPY::VisitExpr_(op, os);
}
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer load is not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
const int value_lanes = value_dtype.lanes();
if (value_lanes == element_dtype.lanes()) {
std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index);
if (ref.back() == ')') {
ref += ".load()";
}
os << ref;
} else {
ICHECK_GE(value_lanes, element_dtype.lanes())
<< "Unsupported load/store: value lanes < buffer element lanes";
bool is_contiguous = false;
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
.Match(index)) {
is_contiguous = true;
}
if (is_contiguous) {
std::string ref =
GetBufferRef_(value_dtype, op->buffer.get(), base.Eval());
if (ref.back() == ')') {
ref += ".load()";
}
os << ref;
} else {
ICHECK(element_dtype.is_scalar())
<< "buffer element type for non-contiguous load must be scalar "
"currently";
std::string sret = name_supply_->FreshName("_");
PrintIndent();
stream << sret << " = tl.make_rmem_tensor((" << value_lanes << ",), ";
PrintType(element_dtype, stream);
stream << ")\n";
std::string vid = GetVarID(buffer_var.get());
const RampNode *ramp = index.as<RampNode>();
ICHECK(ramp)
<< "Expected Ramp index for vectorized non-contiguous access";
for (int i = 0; i < value_lanes; ++i) {
auto idx_expr =
arith::Analyzer().Simplify(ramp->base + ramp->stride * i);
PrintIndent();
stream << sret << "[" << i << "] = "
<< GetBufferRef_(element_dtype, op->buffer.get(), idx_expr)
<< "\n";
}
os << sret << ".load()";
}
}
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const BufferStoreNode *op) {
ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer store is not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index_expr = op->indices[0];
Var buffer_var = op->buffer->data;
std::string value_str = PrintExpr_(op->value);
int value_lanes = value_dtype.lanes();
if (value_lanes == element_dtype.lanes()) {
std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr);
PrintIndent();
if (ref.back() != ')') {
stream << ref << " = " << RemoveOutermostParentheses(value_str) << "\n";
} else {
stream << ref << ".store(" << RemoveOutermostParentheses(value_str)
<< ")\n";
}
} else {
bool is_contiguous = false;
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
.Match(index_expr)) {
is_contiguous = true;
}
if (is_contiguous) {
PrintVecStore_(op->buffer.get(), value_dtype, base.Eval(), value_str);
} else {
ICHECK(element_dtype.is_scalar())
<< "buffer element type for non-contiguous store must be scalar "
"currently";
// store elements separately
value_str = SSAGetID(value_str, element_dtype);
for (int i = 0; i < value_lanes; ++i) {
const RampNode *ramp = index_expr.as<RampNode>();
ICHECK(ramp);
auto idx_expr =
arith::Analyzer().Simplify(ramp->base + ramp->stride * i);
PrintIndent();
stream << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr)
<< " = ";
PrintVecElemLoad_(value_str, value_dtype, i, stream);
stream << "\n";
}
}
}
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
alloc_storage_scope_[op->buffer_var.get()] = scope;
if (scope == "local.descriptor.wgmma") {
stream << vid << " = tl.GmmaDescriptor()\n";
} else if (scope == "local.descriptor.tcgen05_smem") {
LOG(FATAL) << "Currently unsupported scope: " << scope;
} else if (scope == "local.descriptor.tcgen05_instr") {
LOG(FATAL) << "Currently unsupported scope: " << scope;
} else if (scope == "shared.dyn") {
stream << vid << " = tl.make_tensor(tl.get_dyn_smem(";
PrintType(op->dtype, stream);
// There is no bounds check for Tensor access, so just set the shape to 1.
stream << ", alignment=1024), (1,))\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now, but get "
<< constant_size << " for " << op->buffer_var->name_hint;
if (scope == "shared") {
stream << vid << " = tl.make_tensor(tl.alloc_smem(";
PrintType(op->dtype, stream);
stream << ", " << constant_size << "), (" << constant_size << ",))\n";
} else if (scope == "shared.barrier") {
ICHECK(false) << "Unsupported scope: " << scope;
} else if (scope == "local") {
stream << vid << " = tl.make_rmem_tensor((" << constant_size << "),";
PrintType(op->dtype, stream);
stream << ")\n";
} else if (scope == "local.var") {
PrimExpr init = tir::make_const(op->dtype, 0);
auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
if (init_it != op->annotations.end()) {
PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
user_init = tir::Cast(op->dtype, user_init);
}
init = user_init;
}
stream << vid << " = " << PrintExpr_(init) << "\n";
} else {
ICHECK(false) << "Unsupported scope: " << scope;
}
}
RegisterHandleType_(op->buffer_var.get(), op->dtype);
PrintStmt_(op->body);
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (!iv->thread_tag.empty()) {
if (!var_idmap_.count(iv->var.get())) {
BindThreadIndex_(iv);
}
}
VisitStmt(op->body);
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode *queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
VisitExpr(commit_group, stream);
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group =
Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
VisitExpr(wait_group, stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
VisitStmt(inner->body);
} else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent();
const StringImmNode *pattern = op->value.as<StringImmNode>();
ICHECK(pattern);
std::string call_str = pattern->value;
// Replace "::" with ".", "<" with "(", and ">" with ")".
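// e.g. (illustrative, hypothetical pattern name) "ns::swizzle<8>" -> "ns.swizzle(8)"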
ReplaceAll(call_str, "::", ".");
ReplaceAll(call_str, "<", "(");
ReplaceAll(call_str, ">", ")");
this->stream << "blockIdx = " << call_str << "\n";
this->VisitStmt(op->body);
} else if (op->attr_key == "pragma_unroll_factor") {
const IntImmNode *factor = op->value.as<IntImmNode>();
ICHECK(factor);
unroll_factor_[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
CodeGenTileLangPY::VisitStmt_(op);
} else {
CodeGenTileLangPY::VisitStmt_(op);
}
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const ForNode *op) {
if (op->kind != tir::ForKind::kUnrolled) {
CodeGenTileLangPY::VisitStmt_(op);
return;
}
auto start_expr = arith::Analyzer().Simplify(op->min);
auto stop_expr = arith::Analyzer().Simplify(op->extent + op->min);
std::string unroll_factor;
if (auto it = unroll_factor_.find(op->loop_var.get());
it != unroll_factor_.end()) {
unroll_factor = PrintExpr_(it->second);
}
bool use_range_constexpr = unroll_factor.empty() &&
as_const_int(op->extent) != nullptr &&
*as_const_int(op->extent) <= LOOP_UNROLL_THRESHOLD;
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
stream << "for " << vid << " in cutlass.range";
if (use_range_constexpr) {
stream << "_constexpr";
}
stream << "(";
if (!is_zero(start_expr)) {
PrintExpr_(start_expr, stream);
stream << ", ";
}
PrintExpr_(stop_expr, stream);
if (!unroll_factor.empty()) {
stream << ", unroll=" << unroll_factor;
} else if (!use_range_constexpr) {
stream << ", unroll_full=True";
}
stream << "):\n";
int for_scope = BeginScope();
PrintStmt_(op->body);
EndScope(for_scope);
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const IfThenElseNode *op) {
std::string cond = PrintExpr_(op->condition);
PrintIndent();
stream << "if " << RemoveOutermostParentheses(cond) << ":\n";
int then_scope = BeginScope();
if (const CallNode *call = op->condition.as<CallNode>();
call && call->op.same_as(tl::tl_shuffle_elect())) {
PrintIndent();
stream << "with cute.arch.elect_one():\n";
int with_scope = BeginScope();
PrintStmt_(op->then_case);
EndScope(with_scope);
} else {
PrintStmt_(op->then_case);
}
EndScope(then_scope);
if (op->else_case) {
PrintIndent();
stream << "else:\n";
int else_scope = BeginScope();
PrintStmt_(op->else_case.value());
EndScope(else_scope);
}
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const EvaluateNode *op) {
if (is_const_int(op->value))
return;
const CallNode *call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
LOG(FATAL) << "Currently unsupported op: " << call->op;
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0]));
PrintIndent();
stream << "assert " << cond << "\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0]));
std::string msg_expr = PrintExpr_(call->args[1]);
PrintIndent();
stream << "assert " << cond << ", " << msg_expr << "\n";
} else if (call && call->op.same_as(builtin::tvm_storage_sync())) {
PrintStorageSync_(call);
} else {
CodeGenTileLangPY::VisitStmt_(op);
}
}
void CodeGenTileLangCuTeDSL::PrintVecElemLoad_(const std::string &vec,
DataType t, int i,
std::ostream &os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}
os << vec << "[" << i << "]";
}
void CodeGenTileLangCuTeDSL::PrintVecElemStore_(const std::string &vec,
DataType t, int i,
const std::string &value) {
PrintIndent();
stream << vec << "[" << i << "] = " << value << "\n";
}
void CodeGenTileLangCuTeDSL::PrintVecStore_(const BufferNode *buffer,
DataType t, PrimExpr base,
const std::string &value) {
ICHECK(!t.is_scalar()) << "PrintVecStore_() should not be used for scalar";
std::string ref = GetBufferRef_(t, buffer, base);
PrintIndent();
stream << ref << ".store(" << value << ")\n";
}
void CodeGenTileLangCuTeDSL::PrintVecBinaryOp_(const std::string &opstr,
DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
// Declare the result.
std::string sret = name_supply_->FreshName("_");
PrintIndent();
stream << sret << " = tl.make_rmem_tensor((" << dtype.lanes() << ",), ";
PrintType(dtype.element_of(), stream);
stream << ")\n";
std::string vlhs = SSAGetID(PrintExpr_(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr_(rhs), rhs.dtype());
const std::string one_char_op{"+-*%<>^|&"};
const std::string two_char_op{"// == != <= >="};
if ((opstr.size() == 1 && one_char_op.find(opstr) != std::string::npos) ||
(opstr.size() == 2 && two_char_op.find(opstr) != std::string::npos)) {
PrintIndent();
stream << sret << ".store(" << vlhs << " " << opstr << " " << vrhs << ")\n";
} else {
// Unpack into individual ops.
for (int i = 0, lanes = dtype.lanes(); i < lanes; ++i) {
std::ostringstream value_temp;
if (isalpha(opstr[0])) {
value_temp << opstr << "(";
PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp);
value_temp << ", ";
PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp);
value_temp << opstr;
PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore_(sret, dtype, i, value_temp.str());
}
}
os << sret << ".load()";
}
void CodeGenTileLangCuTeDSL::PrintBinaryExpr_(const std::string &opstr,
DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
if (dtype.is_scalar()) {
CodeGenTileLangPY::PrintBinaryExpr_(opstr, dtype, lhs, rhs, os);
} else {
PrintVecBinaryOp_(opstr, dtype, lhs, rhs, os);
}
}
void CodeGenTileLangCuTeDSL::PrintBinaryIntrinsic_(
const CallNode *op, const char *opstr,
std::ostream &os) { // NOLINT(*)
if (op->dtype.is_scalar()) {
CodeGenTileLangPY::PrintBinaryIntrinsic_(op, opstr, os);
} else {
PrintVecBinaryOp_(opstr, op->dtype, op->args[0], op->args[1], os);
}
}
void CodeGenTileLangCuTeDSL::PrintCallExtern_(Type ret_type,
ffi::String global_symbol,
const ffi::Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
std::string global_symbol_str = global_symbol;
ReplaceAll(global_symbol_str, "::", ".");
std::vector<std::string> sargs;
// When the global symbol ends with template arguments, merge them into the
// function arguments.
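// e.g. (illustrative, hypothetical symbol) "ns::func<8, true>" becomes a call
// to "ns.func" with "8, True" prepended to its printed argument list.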
if (global_symbol_str.back() == '>') {
auto pos = global_symbol_str.rfind('<');
ICHECK(pos != std::string::npos);
std::string template_args =
global_symbol_str.substr(pos + 1, global_symbol_str.size() - pos - 2);
ReplaceAll(template_args, "true", "True");
ReplaceAll(template_args, "false", "False");
sargs.push_back(template_args);
global_symbol_str.resize(pos);
}
const size_t arg_begin = static_cast<size_t>(skip_first_arg);
for (size_t i = arg_begin; i < args.size(); ++i) {
std::string sarg = PrintExpr_(args[i]);
if (ret_dtype.is_fixed_length_vector()) {
std::string val = SSAGetID(sarg, args[i].dtype());
sargs.push_back(std::move(val));
} else {
sargs.push_back(sarg);
}
}
// Replace "<...>" with "(...)". Nested "<" is not supported
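// e.g. (illustrative, hypothetical symbol) "ns.helper<2, false>.run" becomes
// "ns.helper(2, False).run".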
{
auto pos_left = global_symbol_str.find('<');
while (pos_left != std::string::npos) {
auto pos_right = global_symbol_str.find('>', pos_left + 1);
if (pos_right != std::string::npos) {
auto args =
global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
ReplaceAll(args, "true", "True");
ReplaceAll(args, "false", "False");
global_symbol_str.replace(pos_left, args.size() + 2, "(" + args + ")");
}
pos_left = global_symbol_str.find('<');
}
}
// Special cases:
// Map C math functions to Python/cutedsl equivalents
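// e.g. "expf" is canonicalized to "tl.exp" and "log2f" to "tl.log2"
// (see kFastMathMap above).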
const auto canonicalized_global_symbol_str =
CanonicalizeFastmathFunctionName_(global_symbol_str);
const bool canonicalized = !canonicalized_global_symbol_str.empty();
if (canonicalized) {
global_symbol_str = canonicalized_global_symbol_str;
}
// Atomic Functions
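// e.g. (illustrative) an extern call named "AtomicAdd" becomes "tl.AtomicAdd",
// with its first Buffer argument printed as a raw pointer expression.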
if (global_symbol_str.substr(0, 6) == "Atomic") {
global_symbol_str = "tl." + global_symbol_str;
// Convert first argument (Buffer) to pointer for atomic operations
if (const BufferLoadNode *load = args[arg_begin].as<BufferLoadNode>()) {
ICHECK_EQ(load->indices.size(), 1)
<< "CodeGenTileLangCuTeDSL only supports flat memory";
sargs[0] = GetBufferPtr_(load->buffer.get(), load->indices[0]);
}
}
// Some optional template arguments might be omitted, so name the remaining
// arguments explicitly.
if (global_symbol_str == "tl.gemm_ss" || global_symbol_str == "tl.gemm_rs" ||
global_symbol_str == "tl.gemm_sr" || global_symbol_str == "tl.gemm_rr") {
ICHECK(sargs.size() >= 3);
sargs[sargs.size() - 3] = "A_ptr=" + sargs[sargs.size() - 3];
sargs[sargs.size() - 2] = "B_ptr=" + sargs[sargs.size() - 2];
sargs[sargs.size() - 1] = "C_ptr=" + sargs[sargs.size() - 1];
}
if (ret_dtype.is_fixed_length_vector()) {
// Maybe simplify this if TensorSSA supports this op.
std::string sret = name_supply_->FreshName("_");
PrintIndent();
stream << sret << " = tl.make_rmem_tensor((" << ret_dtype.lanes() << ",), ";
PrintType(ret_dtype.element_of(), stream);
stream << ")\n";
// Emit a scalar call for each lane.
bool has_template_arg = (sargs.size() > args.size() - arg_begin);
for (int i = 0; i < ret_dtype.lanes(); ++i) {
std::ostringstream scall;
scall << global_symbol_str << "(";
for (size_t j = 0; j < sargs.size(); ++j) {
if (j != 0) {
scall << ", ";
}
if (j == 0 && has_template_arg) {
scall << sargs[j];
} else {
PrintVecElemLoad_(
sargs[j],
args[arg_begin + j - static_cast<size_t>(has_template_arg)]
.dtype(),
i, scall);
}
}
if (canonicalized && enable_fastmath_) {
if (!sargs.empty()) {
scall << ", ";
}
scall << "fastmath=True";
}
scall << ")";
PrintVecElemStore_(sret, ret_dtype, i, scall.str());
}
os << sret << ".load()";
} else {
os << global_symbol_str << "(";
for (size_t i = 0; i < sargs.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << sargs[i];
}
if (canonicalized && enable_fastmath_) {
if (!sargs.empty()) {
os << ", ";
}
os << "fastmath=True";
}
os << ")";
}
}
std::string CodeGenTileLangCuTeDSL::GetBufferPtr_(const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
const std::string vid = GetVarID(buffer_var);
DataType buffer_element_dtype = buffer->dtype;
bool is_handle_type_match =
HandleTypeMatch_(buffer_var, buffer_element_dtype);
std::string ptr_str;
if (is_handle_type_match) {
ptr_str = vid + ".iterator";
} else {
ptr_str = "tl.recast_ptr(" + vid +
".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")";
}
std::string index_str = PrintExpr_(index);
return "(" + ptr_str + " + " + index_str + ")";
}
// The following forms can be returned:
// (1) vid
// (2) vid[i]
// (3) tl.make_tensor_at_offset(...)[0]
// (4) tl.make_tensor_at_offset(...)
//
// Form (4) is needed when the whole tensor is loaded or stored.
// It's the only form that ends with ")". Using this fact, BufferLoadNode will
// add ".load()" and BufferStoreNode will add ".store()".
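// Illustrative examples (assuming a matching scalar Float16 buffer "buf" and
// index "i"):
//   (2) buf[i]
//   (3) tl.make_tensor_at_offset(buf.iterator, i, (1,), div_by=1)[0]
//   (4) tl.make_tensor_at_offset(buf.iterator, i, (8,), div_by=8)
//       (e.g. when an 8-lane vector of the element type is loaded or stored)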
std::string CodeGenTileLangCuTeDSL::GetBufferRef_(DataType t,
const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::string vid = GetVarID(buffer_var);
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope == "local.var" || scope.find("local.descriptor") == 0) {
return vid;
}
DataType buffer_element_dtype = buffer->dtype;
bool is_handle_type_match =
HandleTypeMatch_(buffer_var, buffer_element_dtype);
std::string ptr_str;
if (is_handle_type_match) {
ptr_str = vid + ".iterator";
} else {
ptr_str = "tl.recast_ptr(" + vid +
".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")";
}
const std::string index_str = PrintExpr_(index);
if (t == buffer_element_dtype) {
if (is_handle_type_match && buffer_element_dtype.is_scalar() &&
(scope == "local" || scope == "shared" || scope == "shared.dyn" ||
scope == "shared.barrier")) {
// Tensors in these scopes are allocated as one-dimensional, so they can be
// accessed correctly via "[]". Other tensors may be multi-dimensional and
// must be accessed via a pointer; otherwise CuTeDSL would interpret "[]"
// access according to its own visiting order and layout.
return vid + "[" + index_str + "]";
} else {
std::ostringstream os;
os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str
<< ", (1,), div_by=" << buffer_element_dtype.lanes() << ")";
// For vector data types, ".load()" (added by the BufferLoad visitor) is
// needed instead of "[0]".
if (buffer_element_dtype.is_scalar()) {
os << "[0]";
}
return os.str();
}
} else {
const int num = t.bits() * t.lanes();
const int den = buffer_element_dtype.bits() * buffer_element_dtype.lanes();
ICHECK_EQ(num % den, 0) << "Cannot form view: bitwidth not divisible";
int buffer_size = num / den;
std::ostringstream os;
os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str << ", ("
<< buffer_size << ",), div_by=" << buffer_size << ")";
return os.str();
}
}
void CodeGenTileLangCuTeDSL::BindThreadIndex_(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
auto &thread_tag = iv->thread_tag;
ICHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
thread_tag == "threadIdx.z" || thread_tag == "blockIdx.x" ||
thread_tag == "blockIdx.y" || thread_tag == "blockIdx.z");
// cute.arch.thread_idx() and block_idx() are Int32
DataType from_dtype = DataType::Int(32);
var_idmap_[iv->var.get()] =
CastFromTo_(thread_tag, from_dtype, iv->var.dtype());
}
void CodeGenTileLangCuTeDSL::PrintStorageSync_(const CallNode *op) {
auto args = op->args;
const std::string &sync = args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// do nothing
} else if (sync == "shared" || sync == "shared.dyn") {
PrintIndent();
if (args.size() == 1) {
stream << "tl.sync_threads()\n";
} else if (args.size() == 2) {
auto barrier_id_ptr = args[1].as<IntImmNode>();
ICHECK(barrier_id_ptr)
<< "storage_sync barrier_id (args[1]) must be IntImm, got "
<< args[1]->GetTypeKey();
auto barrier_id = barrier_id_ptr->value;
stream << "tl.sync_thread_partial(" << barrier_id << ")\n";
} else if (args.size() == 3) {
auto barrier_id_ptr = args[1].as<IntImmNode>();
ICHECK(barrier_id_ptr)
<< "storage_sync barrier_id (args[1]) must be IntImm, got "
<< args[1]->GetTypeKey();
auto thread_count_ptr = args[2].as<IntImmNode>();
ICHECK(thread_count_ptr)
<< "storage_sync thread_count (args[2]) must be IntImm, got "
<< args[2]->GetTypeKey();
auto barrier_id = barrier_id_ptr->value;
auto thread_count = thread_count_ptr->value;
stream << "tl.sync_thread_partial(" << barrier_id << ", " << thread_count
<< ")\n";
} else {
LOG(FATAL) << "Invalid number of arguments for storage sync: "
<< args.size();
}
} else if (sync == "global") {
LOG(FATAL) << "PrintStorageSync_ for global is not supported for now";
} else {
LOG(FATAL) << "Unknown storage sync scope: " << sync;
}
}
} // namespace codegen
} // namespace tvm
/*!
* \file target/codegen_cutedsl.h
* \brief Utility to generate CuTeDSL code
*/
#ifndef TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#define TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "codegen_py.h"
namespace tvm {
namespace codegen {
class CodeGenTileLangCuTeDSL final : public CodeGenTileLangPY {
public:
CodeGenTileLangCuTeDSL();
protected:
void PrintFuncDecorator_(std::ostream &os) override; // NOLINT(*)
void PreFunctionBody_(const PrimFunc &f) override;
protected:
void PrintType(DataType t, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitStmt_(const BufferStoreNode *op) override;
void VisitStmt_(const AllocateNode *op) override;
void VisitStmt_(const AttrStmtNode *op) override;
void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const IfThenElseNode *op) override;
void VisitStmt_(const EvaluateNode *op) override;
protected:
virtual void PrintVecElemLoad_(const std::string &vec, DataType t, int i,
std::ostream &os); // NOLINT(*)
virtual void PrintVecElemStore_(const std::string &vec, DataType t, int i,
const std::string &value);
virtual void PrintVecStore_(const BufferNode *buffer, DataType t,
PrimExpr base, const std::string &value);
void PrintVecBinaryOp_(const std::string &opstr, DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os); // NOLINT(*)
void PrintBinaryExpr_(const std::string &opstr, DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) override; // NOLINT(*)
void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
std::ostream &os) override; // NOLINT(*)
void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) override; // NOLINT(*)
std::string GetBufferPtr_(const BufferNode *buffer, PrimExpr index);
std::string GetBufferRef_(DataType t, const BufferNode *buffer,
PrimExpr index) override;
/*!
* \brief Print the expression representing the thread tag.
* \param iv The thread index to be bound.
*/
virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*)
virtual void PrintStorageSync_(const CallNode *op);
std::string
CanonicalizeFastmathFunctionName_(const std::string &func_name) const;
private:
// The name of the mbarrier array in shared memory
const std::string mbarrier_name_ = "mbarrier";
std::unordered_map<const VarNode *, IntImm> unroll_factor_;
std::vector<std::string> eviction_policy_names_ = {
"EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
// Fastmath configuration (read from PassContext)
bool enable_fastmath_ = false;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
/*!
* \file codegen_py.cc
*/
#include "codegen_py.h"
#include "codegen_utils.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <cctype>
namespace tvm {
namespace codegen {
void CodeGenTileLangPY::AddFunction(const GlobalVar &gvar, const PrimFunc &f) {
RegisterFunction_(gvar, f);
auto function_name = GetFunctionName_(gvar);
// clear previous generated state.
InitFuncState_(f);
PrintFuncDecorator_(stream);
PrintFunctionSignature_(function_name, f, stream);
stream << ":\n";
int func_scope = BeginScope();
PreFunctionBody_(f);
PrintStmt_(f->body);
EndScope(func_scope);
}
std::string CodeGenTileLangPY::Finish() {
std::ostringstream code;
code << decl_stream.str();
code << stream.str();
return code.str();
}
ffi::String CodeGenTileLangPY::GetFunctionName_(const GlobalVar &gvar) {
auto it = internal_functions_.find(gvar);
ICHECK(it != internal_functions_.end())
<< "Attempted to find name of " << gvar
<< ", but no function with this GlobalVar has been declared";
return it->second;
}
void CodeGenTileLangPY::RegisterFunction_(const GlobalVar &gvar,
const PrimFunc &func) {
if (internal_functions_.count(gvar)) {
return;
}
auto function_name = [&]() -> ffi::String {
if (auto global_symbol =
func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
auto name = global_symbol.value();
ICHECK(!func_name_supply_->ContainsName(name))
<< "Function " << gvar << " must use global symbol " << name
<< ", but this name has already been used.";
func_name_supply_->ReserveName(name);
return name;
} else {
ICHECK(!func_name_supply_->ContainsName(gvar->name_hint))
<< "Function " << gvar << " must use name hint " << gvar->name_hint
<< ", but this name has already been used.";
func_name_supply_->ReserveName(gvar->name_hint);
return gvar->name_hint;
}
}();
internal_functions_.insert({gvar, function_name});
}
void CodeGenTileLangPY::InitFuncState_(const PrimFunc &f) {
alloc_storage_scope_.clear();
handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState();
ReserveKeywordsAsUnique_();
}
void CodeGenTileLangPY::PrintFunctionSignature_(
const ffi::String &function_name, const PrimFunc &func,
std::ostream &os) { // NOLINT(*)
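  // Emits only the "def <name>(<params>)" part, e.g. "def main(A, B, C)";
  // the trailing ":" and the body are printed by the caller (AddFunction).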
os << "def " << function_name << "(";
for (size_t i = 0; i < func->params.size(); ++i) {
tir::Var v = func->params[i];
if (i > 0) {
os << ", ";
}
os << AllocVarID(v.get());
}
os << ")";
// Register handle data type
for (const auto &param : func->params) {
if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType_(param.get(), prim->dtype);
}
}
}
}
void CodeGenTileLangPY::ReserveKeywordsAsUnique_() {
  // Reserve "_" so generated SSA variable names start from _1.
name_supply_->ReserveName("_");
name_supply_->ReserveName("False");
name_supply_->ReserveName("None");
name_supply_->ReserveName("True");
name_supply_->ReserveName("and");
name_supply_->ReserveName("as");
name_supply_->ReserveName("assert");
name_supply_->ReserveName("async");
name_supply_->ReserveName("await");
name_supply_->ReserveName("break");
name_supply_->ReserveName("class");
name_supply_->ReserveName("continue");
name_supply_->ReserveName("def");
name_supply_->ReserveName("del");
name_supply_->ReserveName("elif");
name_supply_->ReserveName("else");
name_supply_->ReserveName("except");
name_supply_->ReserveName("finally");
name_supply_->ReserveName("for");
name_supply_->ReserveName("from");
name_supply_->ReserveName("global");
name_supply_->ReserveName("if");
name_supply_->ReserveName("import");
name_supply_->ReserveName("in");
name_supply_->ReserveName("is");
name_supply_->ReserveName("lambda");
name_supply_->ReserveName("nonlocal");
name_supply_->ReserveName("not");
name_supply_->ReserveName("or");
name_supply_->ReserveName("pass");
name_supply_->ReserveName("raise");
name_supply_->ReserveName("return");
name_supply_->ReserveName("try");
name_supply_->ReserveName("while");
name_supply_->ReserveName("with");
name_supply_->ReserveName("yield");
name_supply_->ReserveName("void");
name_supply_->ReserveName("int");
name_supply_->ReserveName("float");
name_supply_->ReserveName("double");
name_supply_->ReserveName("char");
name_supply_->ReserveName("unsigned");
name_supply_->ReserveName("short");
name_supply_->ReserveName("long");
name_supply_->ReserveName("cutlass");
name_supply_->ReserveName("cute");
name_supply_->ReserveName("tl");
}
void CodeGenTileLangPY::PrintSSAAssign(const std::string &target,
const std::string &src, DataType t) {
stream << target << " = " << RemoveOutermostParentheses(src) << "\n";
}
void CodeGenTileLangPY::PrintType(DataType type,
std::ostream &os) { // NOLINT(*)
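  // Map TIR scalar dtypes to Python builtins: float16/32/64 -> float,
  // (u)int8/16/32/64 -> int, 1-bit integers -> bool; anything else is fatal.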
if (type.is_float()) {
if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) {
os << "float";
} else {
LOG(FATAL) << "Cannot convert float" << type.bits() << " to Python type";
}
} else if (type.is_uint()) {
switch (type.bits()) {
case 8:
case 16:
case 32:
case 64: {
os << "int";
break;
}
case 1:
os << "bool";
break;
default:
LOG(FATAL) << "Cannot convert uint" << type.bits() << " to Python type";
}
} else if (type.is_int()) {
switch (type.bits()) {
case 8:
case 16:
case 32:
case 64: {
os << "int";
break;
}
case 1:
os << "bool";
break;
default:
LOG(FATAL) << "Cannot convert int" << type.bits() << " to Python type";
}
} else {
LOG(FATAL) << "Cannot convert type " << type << " to Python type";
}
}
void CodeGenTileLangPY::VisitExpr_(const VarNode *op,
std::ostream &os) { // NOLINT(*)
os << GetVarID(op);
}
void CodeGenTileLangPY::VisitExpr_(const IntImmNode *op,
std::ostream &os) { // NOLINT(*)
if (op->dtype == DataType::Bool()) {
os << (op->value ? "True" : "False");
} else {
std::ostringstream temp;
temp << op->value;
MarkConst(temp.str());
os << temp.str();
}
}
void CodeGenTileLangPY::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
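  // Emit exact hexadecimal float literals so values round-trip without
  // rounding error, e.g. 2.5 is printed as float.fromhex('0x1.4p+1').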
switch (op->dtype.bits()) {
case 64:
case 32: {
std::ostringstream temp;
temp << "float.fromhex('" << std::hexfloat << op->value << "')";
MarkConst(temp.str());
os << temp.str();
break;
}
case 16: {
PrintType(op->dtype, os);
os << "(float.fromhex('" << std::hexfloat << op->value << "'))";
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op,
std::ostream &os) { // NOLINT(*)
EscapeStringLiteral_(op->value, os);
}
void CodeGenTileLangPY::VisitExpr_(const CastNode *op,
std::ostream &os) { // NOLINT(*)
std::stringstream value;
PrintExpr_(op->value, value);
os << CastFromTo_(value.str(), op->value.dtype(), op->dtype);
}
void CodeGenTileLangPY::VisitExpr_(const AddNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("+", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const SubNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("-", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const MulNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("*", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const DivNode *op,
std::ostream &os) { // NOLINT(*)
if (op->dtype.is_int() || op->dtype.is_uint()) {
PrintBinaryExpr_("//", op->dtype, op->a, op->b, os);
} else {
PrintBinaryExpr_("/", op->dtype, op->a, op->b, os);
}
}
void CodeGenTileLangPY::VisitExpr_(const ModNode *op,
std::ostream &os) { // NOLINT(*)
ICHECK(op->dtype.is_int() || op->dtype.is_uint() || op->dtype.is_float())
<< "Expected floating point or integer dtype in Mod, but got "
<< op->dtype;
PrintBinaryExpr_("%", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("min", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("max", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const EQNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("==", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const NENode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("!=", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const LTNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("<", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const LENode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("<=", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const GTNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_(">", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const GENode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_(">=", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const AndNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("and", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const OrNode *op,
std::ostream &os) { // NOLINT(*)
PrintBinaryExpr_("or", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const NotNode *op,
std::ostream &os) { // NOLINT(*)
os << "(not ";
PrintExpr_(op->a, os);
os << ")";
}
void CodeGenTileLangPY::VisitExpr_(const SelectNode *op,
std::ostream &os) { // NOLINT(*)
os << "(";
PrintExpr_(op->true_value, os);
os << " if ";
PrintExpr_(op->condition, os);
os << " else ";
PrintExpr_(op->false_value, os);
os << ")";
}
void CodeGenTileLangPY::VisitExpr_(const RampNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = op->dtype.lanes();
os << "(";
for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr_(op->base) << ")"
<< "+(" << PrintExpr_(op->stride) << "*" << i << ")";
if (i != lanes - 1)
os << ", ";
}
os << ")";
}
void CodeGenTileLangPY::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
if (auto opt_call_op = op->op.as<Op>()) {
const auto &call_op = opt_call_op.value();
if (op->op.same_as(builtin::ret())) {
os << "return " << RemoveOutermostParentheses(PrintExpr_(op->args[0]));
} else if (op->op.same_as(builtin::continue_loop())) {
os << "continue";
} else if (op->op.same_as(builtin::break_loop())) {
os << "break";
} else if (op->op.same_as(builtin_call_extern_) ||
op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)), func->value,
op->args, true, os);
} else if (op_attr_global_symbol_.count(call_op)) {
      // Call extern if the op itself has a global symbol.
PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)),
op_attr_global_symbol_[call_op], op->args, false, os);
} else if (op->op.same_as(builtin::large_uint_imm())) {
ICHECK_EQ(op->args.size(), 2U);
uint64_t low =
static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high =
static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
if (op->dtype == DataType::UInt(32)) {
std::ostringstream temp;
temp << val;
MarkConst(temp.str());
os << temp.str();
} else {
PrintType(op->dtype, os);
os << "(" << val << ")";
}
} else if (op->op.same_as(builtin::bitwise_and())) {
PrintBinaryIntrinsic_(op, "&", os);
} else if (op->op.same_as(builtin::bitwise_or())) {
PrintBinaryIntrinsic_(op, "|", os);
} else if (op->op.same_as(builtin::bitwise_xor())) {
PrintBinaryIntrinsic_(op, "^", os);
} else if (op->op.same_as(builtin::bitwise_not())) {
ICHECK_EQ(op->args.size(), 1U);
os << "~";
PrintExpr_(op->args[0], os);
} else if (op->op.same_as(builtin::shift_left())) {
PrintBinaryIntrinsic_(op, "<<", os);
} else if (op->op.same_as(builtin::shift_right())) {
PrintBinaryIntrinsic_(op, ">>", os);
} else if (op->op.same_as(builtin::if_then_else())) {
std::string cond = PrintExpr_(op->args[0]);
std::string true_val = PrintExpr_(op->args[1]);
std::string false_val = PrintExpr_(op->args[2]);
os << "(" << true_val << " if " << cond << " else " << false_val << ")";
} else if (op->op.same_as(builtin::isnullptr())) {
ICHECK_EQ(op->args.size(), 1U);
os << "(";
PrintExpr_(op->args[0], os);
os << " is None)";
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
PrintExpr_(op->args[0], os);
os << " != ";
PrintExpr_(op->args[0], os);
os << ")";
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
} else if (auto opt = op->op.as<GlobalVar>()) {
const auto &gvar = opt.value();
auto callee_name = GetFunctionName_(gvar);
PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)), callee_name, op->args,
false, os);
} else {
LOG(FATAL)
<< "CodeGenTileLangPY: Unknown operation " << op->op
<< " is neither a recognized built-in, "
<< "nor a GlobalVar reference to another function in the IRModule";
}
}
void CodeGenTileLangPY::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer load is not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
ICHECK_EQ(value_dtype, element_dtype)
<< "value_dtype and element_dtype must be same for a BufferLoadNode";
std::string ref = GetBufferRef_(op->dtype, op->buffer.get(), index);
os << ref;
}
void CodeGenTileLangPY::VisitStmt_(const BufferStoreNode *op) {
ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer store is not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index_expr = op->indices[0];
Var buffer_var = op->buffer->data;
ICHECK_EQ(value_dtype, element_dtype)
<< "value_dtype and element_dtype must be same for a BufferStoreNode";
std::string value = PrintExpr_(op->value);
std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr);
PrintIndent();
stream << ref << " = " << RemoveOutermostParentheses(value) << "\n";
}
void CodeGenTileLangPY::VisitStmt_(const DeclBufferNode *op) {
PrintStmt_(op->body);
}
void CodeGenTileLangPY::VisitStmt_(const LetStmtNode *op) {
std::string value = PrintExpr_(op->value);
PrintIndent();
stream << AllocVarID(op->var.get()) << " = " << value << "\n";
PrintStmt_(op->body);
}
void CodeGenTileLangPY::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
PrintIndent();
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
auto scope = GetPtrStorageScope(op->buffer_var);
alloc_storage_scope_[op->buffer_var.get()] = scope;
stream << vid << " = [None] * " << constant_size << "\n";
RegisterHandleType_(op->buffer_var.get(), op->dtype);
PrintStmt_(op->body);
}
void CodeGenTileLangPY::VisitStmt_(const AttrStmtNode *op) {
PrintStmt_(op->body);
}
void CodeGenTileLangPY::VisitStmt_(const ForNode *op) {
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
stream << "for " << vid << " in range(";
if (is_zero(op->min)) {
PrintExpr_(op->extent, stream);
} else {
PrintExpr_(op->min, stream);
stream << ", ";
PrimExpr upper_bound = arith::Analyzer().Simplify(op->extent + op->min);
PrintExpr_(upper_bound, stream);
}
stream << "):\n";
int for_scope = BeginScope();
PrintStmt_(op->body);
EndScope(for_scope);
}
void CodeGenTileLangPY::VisitStmt_(const WhileNode *op) {
std::string cond = PrintExpr_(op->condition);
PrintIndent();
stream << "while " << RemoveOutermostParentheses(cond) << ":\n";
int while_scope = BeginScope();
PrintStmt_(op->body);
EndScope(while_scope);
}
void CodeGenTileLangPY::VisitStmt_(const IfThenElseNode *op) {
std::string cond = PrintExpr_(op->condition);
PrintIndent();
stream << "if " << RemoveOutermostParentheses(cond) << ":\n";
int then_scope = BeginScope();
PrintStmt_(op->then_case);
EndScope(then_scope);
if (op->else_case) {
PrintIndent();
stream << "else:\n";
int else_scope = BeginScope();
PrintStmt_(op->else_case.value());
EndScope(else_scope);
}
}
void CodeGenTileLangPY::VisitStmt_(const SeqStmtNode *op) {
for (Stmt stmt : op->seq) {
PrintStmt_(stmt);
}
}
void CodeGenTileLangPY::VisitStmt_(const EvaluateNode *op) {
if (is_const_int(op->value))
return;
std::string vid = PrintExpr_(op->value);
if (!vid.empty()) {
PrintIndent();
stream << vid << "\n";
}
}
void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) {
std::string cond = PrintExpr_(op->condition);
PrintIndent();
if (const auto *str = op->message.as<StringImmNode>()) {
stream << "assert " << cond << ", ";
EscapeStringLiteral_(str->value, stream);
stream << "\n";
} else {
stream << "assert " << cond << "\n";
}
PrintStmt_(op->body);
}
std::string CodeGenTileLangPY::CastFromTo_(const std::string &value,
DataType from, DataType target) {
if (from == target)
return value;
std::ostringstream os;
PrintType(target, os);
os << "(" << value << ")";
return os.str();
}
void CodeGenTileLangPY::PrintBinaryExpr_(const std::string &opstr,
DataType dtype, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(dtype.lanes(), 1);
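  // Alphabetic operators (other than the keywords "and"/"or") are printed as
  // function calls, e.g. min(a, b); all others as parenthesized infix, e.g. (a + b).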
if (isalpha(opstr[0]) && opstr != "and" && opstr != "or") {
os << opstr << '(';
PrintExpr_(lhs, os);
os << ", ";
PrintExpr_(rhs, os);
os << ')';
} else {
os << '(';
PrintExpr_(lhs, os);
os << ' ' << opstr << ' ';
PrintExpr_(rhs, os);
os << ')';
}
}
void CodeGenTileLangPY::PrintBinaryIntrinsic_(const CallNode *op,
const char *opstr,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(op->dtype.lanes(), 1);
ICHECK_EQ(op->args.size(), 2U);
os << '(';
PrintExpr_(op->args[0], os);
os << ' ' << opstr << ' ';
PrintExpr_(op->args[1], os);
os << ')';
}
void CodeGenTileLangPY::PrintCallExtern_(Type ret_type,
ffi::String global_symbol,
const ffi::Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
os << global_symbol << "(";
for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
PrintExpr_(args[i], os);
if (i < args.size() - 1) {
os << ", ";
}
}
os << ")";
}
// Print a reference expression to a buffer.
std::string CodeGenTileLangPY::GetBufferRef_(DataType t,
const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::string vid = GetVarID(buffer_var);
DataType buffer_element_dtype = buffer->dtype;
ICHECK(HandleTypeMatch_(buffer_var, buffer_element_dtype));
ICHECK_EQ(t, buffer_element_dtype);
std::string index_str = PrintExpr_(index);
return vid + "[" + index_str + "]";
}
void CodeGenTileLangPY::RegisterHandleType_(const VarNode *buf_var,
DataType t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
} else {
ICHECK(it->second == t) << "conflicting buf var type";
}
}
bool CodeGenTileLangPY::HandleTypeMatch_(const VarNode *buf_var,
DataType t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end())
return false;
return it->second == t;
}
void CodeGenTileLangPY::EscapeStringLiteral_(const std::string &s,
std::ostream &os) {
os << '"';
for (unsigned char c : s) {
switch (c) {
case '\\':
os << "\\\\";
break;
case '"':
os << "\\\"";
break;
case '\n':
os << "\\n";
break;
case '\r':
os << "\\r";
break;
case '\t':
os << "\\t";
break;
case '\f':
os << "\\f";
break;
case '\b':
os << "\\b";
break;
default:
      // Escape non-printable ASCII bytes as \xHH; other bytes (e.g. UTF-8) pass through.
if (c < 32 || c == 127) {
// Output as \xHH
os << "\\x";
const char hex[] = "0123456789abcdef";
os << hex[(c >> 4) & 0xF];
os << hex[c & 0xF];
} else {
os << c;
}
break;
}
}
os << '"';
}
} // namespace codegen
} // namespace tvm
/*!
* \file codegen_py.h
* \brief Common utilities to generate simple Python code.
*/
#ifndef TVM_TL_TARGET_CODEGEN_PY_H_
#define TVM_TL_TARGET_CODEGEN_PY_H_
#include <tvm/ir/op.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include <unordered_map>
// from tvm/src/
#include "target/source/codegen_source_base.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace codegen {
using namespace tir;
/*!
* \brief A base class to generate simple Python code.
*/
class CodeGenTileLangPY
: public ExprFunctor<void(const PrimExpr &, std::ostream &)>,
public StmtFunctor<void(const Stmt &)>,
public CodeGenSourceBase {
public:
/*!
* \brief Add the function definition to the generated module.
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void AddFunction(const GlobalVar &gvar, const PrimFunc &func);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
virtual std::string Finish();
protected:
/*!
* \brief Get the name of a declared function
* \param gvar The GlobalVar of the function
* \returns The string name of the function
*/
ffi::String GetFunctionName_(const GlobalVar &gvar);
/*!
* \brief Reserve the function name in the generated module.
*
* \param gvar The GlobalVar representing the function.
* \param func The function to be compiled.
*/
virtual void RegisterFunction_(const GlobalVar &gvar, const PrimFunc &func);
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
virtual void InitFuncState_(const PrimFunc &f);
/*! \brief Print the function signature before ":"
* \param function_name The name of the function
* \param func The function whose signature should be printed
* \param os The output stream
*/
virtual void PrintFunctionSignature_(const ffi::String &function_name,
const PrimFunc &func,
std::ostream &os); // NOLINT(*)
/*!
* \brief Print the function decorator
* \param os The output stream
*/
virtual void PrintFuncDecorator_(std::ostream &os) {} // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
*/
virtual void PreFunctionBody_(const PrimFunc &f) {}
protected:
/*! \brief reserves common Python keywords */
void ReserveKeywordsAsUnique_();
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType t) override;
protected:
/*!
   * \brief Print the Python representation of a data type.
   * \param type The data type to print.
* \param os The output stream
*/
void PrintType(DataType type, std::ostream &os) override; // NOLINT(*)
/*!
* \brief Print the Stmt n to CodeGenTileLangPY->stream
* \param n The statement to be printed.
*/
void PrintStmt_(const Stmt &n) { VisitStmt(n); }
/*!
* \brief Print the expression n into os
* \param n The expression to be printed.
* \param os The output stream
*/
void PrintExpr_(const PrimExpr &n, std::ostream &os) { // NOLINT(*)
VisitExpr(n, os);
}
/*!
* \brief Same as PrintExpr_, but simply returns result string
* \param n The expression to be printed.
*/
std::string PrintExpr_(const PrimExpr &n) {
std::ostringstream os;
PrintExpr_(n, os);
return os.str();
}
// expression
void VisitExpr_(const VarNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const StringImmNode *op,
std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const AddNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const SubNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MulNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const DivNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const ModNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MinNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const EQNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const NENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const LTNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const LENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const GTNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const GENode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const AndNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const OrNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const NotNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const RampNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) override; // NOLINT(*)
  // statement
void VisitStmt_(const BufferStoreNode *op) override;
void VisitStmt_(const DeclBufferNode *op) override;
void VisitStmt_(const LetStmtNode *op) override;
void VisitStmt_(const AllocateNode *op) override;
void VisitStmt_(const AttrStmtNode *op) override;
void VisitStmt_(const ForNode *op) override;
void VisitStmt_(const WhileNode *op) override;
void VisitStmt_(const IfThenElseNode *op) override;
void VisitStmt_(const SeqStmtNode *op) override;
void VisitStmt_(const EvaluateNode *op) override;
void VisitStmt_(const AssertStmtNode *op) override;
protected:
// Get a string of type casting
virtual std::string CastFromTo_(const std::string &value, DataType from,
DataType target);
virtual void PrintBinaryExpr_(const std::string &opstr, DataType dtype,
PrimExpr lhs, PrimExpr rhs,
std::ostream &os); // NOLINT(*)
virtual void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
std::ostream &os); // NOLINT(*)
/*!
* \brief Print external function call.
* \param ret_type The return type.
   * \param global_symbol The symbol of the target function.
   * \param args The arguments to the function.
   * \param skip_first_arg Whether to skip the first argument.
* \param os The output stream.
*/
virtual void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
const ffi::Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os); // NOLINT(*)
  // Print a reference to the buffer element of type t at the given index.
virtual std::string GetBufferRef_(DataType t, const BufferNode *buffer,
PrimExpr index);
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void RegisterHandleType_(const VarNode *buf_var, DataType t);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool HandleTypeMatch_(const VarNode *buf_var, DataType t) const;
protected:
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode *, std::string> alloc_storage_scope_;
/*! \brief Record of ops that have pre-defined global symbol. */
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ =
Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
// cache commonly used ops
const Op &builtin_call_extern_ = builtin::call_extern();
const Op &builtin_call_pure_extern_ = builtin::call_pure_extern();
private:
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode *, DataType> handle_data_type_;
  /* \brief Map from each GlobalVar to its symbol.
   *
   * For externally-exposed functions, this is given by the
   * tvm::attr::kGlobalSymbol attribute of the PrimFunc. For internal
* functions, this is the name of the function's GlobalVar, possibly
* altered to prevent duplicate names.
*/
std::unordered_map<GlobalVar, ffi::String> internal_functions_;
/* \brief Name supply to generate unique function names */
NameSupply func_name_supply_;
/*!
* \brief Escape a string to be a valid Python double-quoted string literal.
* \param s The input string to escape.
* \param os The output stream to write the escaped string to.
*/
void EscapeStringLiteral_(const std::string &s, std::ostream &os);
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_PY_H_
/*!
* \file target/codegen_utils.cc
* \brief Shared utility functions for code generation
*/
#include "codegen_utils.h"
namespace tvm {
namespace codegen {
bool CheckOutermostParenthesesMatch(const std::string &s) {
if (!s.empty() && s.front() == '(' && s.back() == ')') {
size_t len = s.size();
int n_unmatched = 0;
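    // Track parenthesis depth: the leading '(' pairs with the trailing ')'
    // only if the depth first returns to zero at the final character,
    // e.g. "(a + b)" matches but "(a) + (b)" does not.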
for (size_t i = 0; i < len; ++i) {
if (s[i] == '(') {
n_unmatched++;
} else if (s[i] == ')') {
n_unmatched--;
}
if (n_unmatched < 0) {
return false;
}
if (n_unmatched == 0) {
return i == len - 1;
}
}
}
return false;
}
std::string RemoveOutermostParentheses(const std::string &s) {
if (CheckOutermostParenthesesMatch(s)) {
return s.substr(1, s.size() - 2);
} else {
return s;
}
}
} // namespace codegen
} // namespace tvm
/*!
* \file target/codegen_utils.h
* \brief Shared utility functions for code generation
*/
#ifndef TVM_TARGET_CODEGEN_UTILS_H_
#define TVM_TARGET_CODEGEN_UTILS_H_
#include <string>
namespace tvm {
namespace codegen {
/*!
* \brief Check if the outermost parentheses match
* \param s The input string
* \return true if the first character is '(' and the last character is ')'
* and they form a matching pair
*/
bool CheckOutermostParenthesesMatch(const std::string &s);
/*!
* \brief Remove outermost parentheses if they match
* \param s The input string
* \return The string with outermost parentheses removed if they match,
* otherwise return the original string
*/
std::string RemoveOutermostParentheses(const std::string &s);
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_UTILS_H_
#include "codegen_cutedsl.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
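          // grid_constant handle parameters (e.g. TMA descriptors) are passed to
          // the kernel by value, so tag them with the kDLGridConstant type code.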
info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) {
CodeGenTileLangCuTeDSL cg;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCuTeDSL: Can only take PrimFunc";
auto gvar = Downcast<GlobalVar>(kv.first);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(gvar, f);
}
std::string code = cg.Finish();
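  // Allow user-registered postprocessing of the generated CuTeDSL source, e.g.
  // the "tilelang_callback_cutedsl_postproc" hook exercised by the JIT tests.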
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc")) {
code = (*f)(code, target).cast<std::string>();
}
return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile",
BuildTileLangCuTeDSLWithoutCompile);
}
} // namespace codegen
} // namespace tvm
...@@ -173,4 +173,4 @@ template <class T, unsigned I = 0> ...@@ -173,4 +173,4 @@ template <class T, unsigned I = 0>
inline constexpr size_t extent_v = extent<T, I>::value; inline constexpr size_t extent_v = extent<T, I>::value;
} // namespace std } // namespace std
#endif #endif // __CUDACC_RTC__
\ No newline at end of file
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
stramp = "&*(XS)"
@tvm.register_global_func("tilelang_callback_cutedsl_postproc", override=True)
def tilelang_callback_cutedsl_postproc(code, _):
code = f"# {stramp}\n" + code
return code
matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl")
kernel_source = matmul_kernel.get_kernel_source()
assert stramp in kernel_source, f"Expected {stramp} in the kernel source"
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, out_idx=-1, target="cutedsl")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
A = A.T
if trans_B:
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
ref_C = ref_program(A, B)
C = matmul_kernel(A, B)
tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def run_cutedsl_kernel_do_bench(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
profiler = matmul_kernel.get_profiler()
cutedsl_latency = profiler.do_bench(func=matmul_kernel)
print(f"CuTeDSL Latency: {cutedsl_latency} ms")
assert cutedsl_latency is not None
tvm_latency = profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")
assert tvm_latency is not None
def test_cutedsl_kernel_do_bench():
run_cutedsl_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cutedsl_kernel_multi_stream(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
num_streams = 4
for _ in range(num_streams):
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
matmul_kernel(tensor_a, tensor_b, tensor_c)
def test_cutedsl_kernel_multi_stream():
run_cutedsl_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_cutedsl_dynamic_shape(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
matmul_kernel = tilelang.compile(program, target="cutedsl")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):
N = 1024
if isinstance(K, T.Var):
K = 768
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_cutedsl_dynamic_shape():
run_cutedsl_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cutedsl_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cutedsl_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
)
def check_hopper():
if not torch.cuda.is_available():
return False
props = torch.cuda.get_device_properties(0)
compute_capability = props.major, props.minor
return compute_capability == (9, 0)
if __name__ == "__main__":
tilelang.testing.main()
...@@ -29,6 +29,11 @@ KERNEL_CUBIN_PATH = "kernel.cubin" ...@@ -29,6 +29,11 @@ KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py" KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl" PARAMS_PATH = "params.pkl"
# CuTeDSL C++ launcher specific
LAUNCHER_LIB_PATH = "launcher_lib.so"
LAUNCHER_CPP_PATH = "launcher.cpp"
CUTEDSL_CUBIN_PATH = "kernel.cubin"
class KernelCache: class KernelCache:
""" """
...@@ -43,7 +48,7 @@ class KernelCache: ...@@ -43,7 +48,7 @@ class KernelCache:
_instance = None # For implementing singleton pattern _instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety _lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary _memory_cache = {} # In-memory cache dictionary
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi" execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi"
def __new__(cls): def __new__(cls):
""" """
...@@ -72,7 +77,7 @@ class KernelCache: ...@@ -72,7 +77,7 @@ class KernelCache:
self, self,
func: Callable, func: Callable,
out_idx: list[int], out_idx: list[int],
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
args=None, args=None,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target = None,
...@@ -85,7 +90,7 @@ class KernelCache: ...@@ -85,7 +90,7 @@ class KernelCache:
Args: Args:
func (Callable): The function to be compiled. func (Callable): The function to be compiled.
out_idx (List[int]): Indices specifying which outputs to return. out_idx (List[int]): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython". execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi".
args: Arguments passed to the function. args: Arguments passed to the function.
target (Union[str, Target]): Compilation target platform. Defaults to "auto". target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform. target_host (Union[str, Target], optional): Host target platform.
...@@ -118,7 +123,7 @@ class KernelCache: ...@@ -118,7 +123,7 @@ class KernelCache:
*args, *args,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto", execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
verbose: bool = False, verbose: bool = False,
pass_configs: dict = None, pass_configs: dict = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
...@@ -217,7 +222,11 @@ class KernelCache: ...@@ -217,7 +222,11 @@ class KernelCache:
) )
with self._lock: with self._lock:
if env.is_cache_enabled(): if env.is_cache_enabled():
cache_path = self._get_cache_path(key)
self._save_kernel_to_disk(key, kernel, func, verbose) self._save_kernel_to_disk(key, kernel, func, verbose)
# Set cache path on adapter so it can save cubin after first execution
if hasattr(kernel, "adapter") and execution_backend == "cutedsl":
kernel.adapter._cache_path = cache_path
# Store in memory cache after compilation # Store in memory cache after compilation
self._memory_cache[key] = kernel self._memory_cache[key] = kernel
...@@ -287,36 +296,59 @@ class KernelCache: ...@@ -287,36 +296,59 @@ class KernelCache:
# Save kernel source code # Save kernel source code
try: try:
if self.execution_backend != "cutedsl":
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH) device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
if verbose: if verbose:
self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}") self.logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
if kernel.kernel_source is not None: if kernel.kernel_source is not None:
KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source)) KernelCache._safe_write_file(device_kernel_path, "w", lambda file: file.write(kernel.kernel_source))
except Exception as e: except Exception:
self.logger.error(f"Error saving kernel source code to disk: {e}") self.logger.exception("Error saving kernel source code to disk")
# Save wrapped kernel source code # Save wrapped kernel source code
try: try:
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH) host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH if self.execution_backend != "cutedsl" else KERNEL_PY_PATH)
if verbose: if verbose:
self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}") self.logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
if self.execution_backend == "tvm_ffi": if self.execution_backend == "tvm_ffi":
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source())) KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_host_source()))
else: else:
KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source())) KernelCache._safe_write_file(host_kernel_path, "w", lambda file: file.write(kernel.adapter.get_kernel_source()))
except Exception as e: except Exception:
self.logger.error(f"Error saving host kernel source code to disk: {e}") self.logger.exception("Error saving host kernel source code to disk")
# Save the kernel library # Save the kernel library
try: try:
# Save CUBIN or SO file # Save CUBIN or SO file
if self.execution_backend == "cutedsl":
# For CuTeDSL, kernel_lib_path is the Python module
kernel_lib_path = os.path.join(cache_path, KERNEL_PY_PATH)
# Save C++ launcher library if it exists
lib_gen = getattr(kernel.adapter, "lib_generator", None)
if lib_gen and hasattr(lib_gen, "launcher_libpath") and lib_gen.launcher_libpath:
launcher_lib_path = os.path.join(cache_path, LAUNCHER_LIB_PATH)
src_launcher_path = lib_gen.launcher_libpath
if verbose:
self.logger.debug(f"Saving C++ launcher library to cache: {src_launcher_path}")
KernelCache._safe_write_file(
launcher_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_launcher_path))
)
# Optionally save launcher C++ source for debugging
if hasattr(kernel.adapter, "launcher_cpp_code") and kernel.adapter.launcher_cpp_code:
launcher_cpp_path = os.path.join(cache_path, LAUNCHER_CPP_PATH)
if verbose:
self.logger.debug(f"Saving C++ launcher source to: {launcher_cpp_path}")
KernelCache._safe_write_file(launcher_cpp_path, "w", lambda file: file.write(kernel.adapter.launcher_cpp_code))
else:
if self.execution_backend == "nvrtc": if self.execution_backend == "nvrtc":
kernel_lib_path = KERNEL_CUBIN_PATH kernel_lib_path = KERNEL_CUBIN_PATH
elif self.execution_backend == "tvm_ffi": elif self.execution_backend == "tvm_ffi":
kernel_lib_path = EXECUTABLE_PATH kernel_lib_path = EXECUTABLE_PATH
else: else:
kernel_lib_path = KERNEL_LIB_PATH kernel_lib_path = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_path) kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
# Save an extra Python file for NVRTC # Save an extra Python file for NVRTC
...@@ -327,7 +359,8 @@ class KernelCache: ...@@ -327,7 +359,8 @@ class KernelCache:
if verbose: if verbose:
self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}") self.logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) KernelCache._safe_write_file(kernel_py_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
elif self.execution_backend == "tvm_ffi":
if self.execution_backend == "tvm_ffi":
executable = kernel.adapter.executable executable = kernel.adapter.executable
if verbose: if verbose:
self.logger.debug(f"Saving kernel executable to file: {executable}") self.logger.debug(f"Saving kernel executable to file: {executable}")
...@@ -338,8 +371,8 @@ class KernelCache: ...@@ -338,8 +371,8 @@ class KernelCache:
self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}") self.logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path))) KernelCache._safe_write_file(kernel_lib_path, "wb", lambda file: file.write(KernelCache._load_binary(src_lib_path)))
except Exception as e: except Exception:
self.logger.error(f"Error saving kernel library to disk: {e}") self.logger.exception("Error saving kernel library to disk")
# Save kernel parameters # Save kernel parameters
try: try:
...@@ -347,19 +380,19 @@ class KernelCache: ...@@ -347,19 +380,19 @@ class KernelCache:
if verbose: if verbose:
self.logger.debug(f"Saving kernel parameters to disk: {params_path}") self.logger.debug(f"Saving kernel parameters to disk: {params_path}")
KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file)) KernelCache._safe_write_file(params_path, "wb", lambda file: cloudpickle.dump(kernel.params, file))
except Exception as e: except Exception:
self.logger.error(f"Error saving kernel parameters to disk: {e}") self.logger.exception("Error saving kernel parameters to disk")
def _load_kernel_from_disk( def _load_kernel_from_disk(
self, self,
key: str, key: str,
target: str | Target = "auto", target: str | Target = "auto",
target_host: str | Target = None, target_host: str | Target | None = None,
out_idx: list[int] = None, out_idx: list[int] | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi", execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
pass_configs: dict = None, pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None, compile_flags: list[str] | str | None = None,
func: Callable = None, func: Callable | None = None,
verbose: bool = False, verbose: bool = False,
) -> JITKernel | None: ) -> JITKernel | None:
""" """
...@@ -370,7 +403,7 @@ class KernelCache: ...@@ -370,7 +403,7 @@ class KernelCache:
target (Union[str, Target]): Compilation target platform. Defaults to "auto". target (Union[str, Target]): Compilation target platform. Defaults to "auto".
target_host (Union[str, Target], optional): Host target platform. target_host (Union[str, Target], optional): Host target platform.
out_idx (List[int], optional): Indices specifying which outputs to return. out_idx (List[int], optional): Indices specifying which outputs to return.
execution_backend (Literal): Backend type for execution. Defaults to "cython". execution_backend (Literal): Backend type for execution. Defaults to "tvm_ffi".
pass_configs (dict, optional): Configuration for compiler passes. pass_configs (dict, optional): Configuration for compiler passes.
func (Callable, optional): The original function. func (Callable, optional): The original function.
verbose (bool): Enable verbose log messages. verbose (bool): Enable verbose log messages.
...@@ -385,11 +418,21 @@ class KernelCache: ...@@ -385,11 +418,21 @@ class KernelCache:
kernel_lib_path = KERNEL_CUBIN_PATH kernel_lib_path = KERNEL_CUBIN_PATH
elif self.execution_backend == "tvm_ffi": elif self.execution_backend == "tvm_ffi":
kernel_lib_path = EXECUTABLE_PATH kernel_lib_path = EXECUTABLE_PATH
elif self.execution_backend == "cutedsl":
kernel_lib_path = KERNEL_PY_PATH
else: else:
kernel_lib_path = KERNEL_LIB_PATH kernel_lib_path = KERNEL_LIB_PATH
kernel_lib_path = os.path.join(cache_path, kernel_lib_path) kernel_lib_path = os.path.join(cache_path, kernel_lib_path)
params_path = os.path.join(cache_path, PARAMS_PATH) params_path = os.path.join(cache_path, PARAMS_PATH)
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
# Check required files exist
required_files = [kernel_lib_path, params_path]
# For CuTeDSL, also check launcher library
if self.execution_backend == "cutedsl":
required_files.append(os.path.join(cache_path, LAUNCHER_LIB_PATH))
if not all([os.path.exists(file) for file in required_files]):
return None return None
device_kernel_source: str | None = None device_kernel_source: str | None = None
...@@ -397,20 +440,25 @@ class KernelCache: ...@@ -397,20 +440,25 @@ class KernelCache:
kernel_params: list[KernelParam] | None = None kernel_params: list[KernelParam] | None = None
# Load the kernel source file (optional) # Load the kernel source file (optional)
if self.execution_backend != "cutedsl":
try: try:
if verbose: if verbose:
self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}") self.logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
with open(device_kernel_path) as f: with open(device_kernel_path) as f:
device_kernel_source = f.read() device_kernel_source = f.read()
except Exception as e: except Exception:
self.logger.error(f"Error loading kernel source code from disk: {e}") self.logger.exception("Error loading kernel source code from disk")
try: try:
if verbose: if verbose:
self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}") self.logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
with open(host_kernel_path) as f: with open(host_kernel_path) as f:
host_kernel_source = f.read() host_kernel_source = f.read()
except Exception as e: except Exception:
self.logger.error(f"Error loading host kernel source code from disk: {e}") self.logger.exception("Error loading host kernel source code from disk")
else:
# For CuTeDSL, set empty strings since sources aren't loaded from cache
device_kernel_source = ""
host_kernel_source = ""
# Load kernel parameters # Load kernel parameters
try: try:
...@@ -418,10 +466,10 @@ class KernelCache: ...@@ -418,10 +466,10 @@ class KernelCache:
self.logger.debug(f"Loading kernel parameters from file: {params_path}") self.logger.debug(f"Loading kernel parameters from file: {params_path}")
with open(params_path, "rb") as f: with open(params_path, "rb") as f:
kernel_params = cloudpickle.load(f) kernel_params = cloudpickle.load(f)
except Exception as e: except Exception:
self.logger.error(f"Error loading kernel parameters from disk: {e}") self.logger.exception("Error loading kernel parameters from disk")
if host_kernel_source and device_kernel_source and kernel_params: if ((host_kernel_source and device_kernel_source) or self.execution_backend == "cutedsl") and kernel_params:
return JITKernel.from_database( return JITKernel.from_database(
func=func, func=func,
host_kernel_source=host_kernel_source, host_kernel_source=host_kernel_source,
...@@ -453,5 +501,5 @@ class KernelCache: ...@@ -453,5 +501,5 @@ class KernelCache:
# Re-create the cache directory # Re-create the cache directory
KernelCache._create_dirs() KernelCache._create_dirs()
except Exception as e: except Exception:
self.logger.error(f"Error clearing disk cache: {e}") self.logger.exception("Error clearing disk cache")
import cutlass
import cutlass.cute as cute
from cutlass._mlir.dialects import nvvm
from cutlass.cutlass_dsl import T
# re-export cutlass.cute.arch functions first
from cutlass.cute.arch import sync_threads # noqa: F401
from cutlass.cute.arch import alloc_smem, get_dyn_smem # noqa: F401
from cutlass.cute.arch import warpgroup_reg_alloc, warpgroup_reg_dealloc # noqa: F401
from cutlass.cute import make_tensor, make_rmem_tensor, recast_ptr # noqa: F401
from cutlass.cute.typing import Numeric
from cutlass.base_dsl.typing import as_numeric, Int32, Uint16, Uint32 # noqa: F401
from cutlass._mlir.dialects import llvm, arith # noqa: F401
from cutlass._mlir import ir as mlir_ir
from cutlass.cutlass_dsl import dsl_user_op
# Import our custom implementations (will override if names conflict)
from .mbar import *
from .cpasync import *
from .gemm_V1 import *
from .reduce import *
from .ldsm import *
from .math import *
from .threadblock_swizzle import *
# Forward nvvm enums
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
AtomicOpKind,
)
BYTES_PER_TENSORMAP = 128
BYTES_PER_POINTER = 8
def make_filled_tensor(shape, value):
t = cute.make_rmem_tensor(shape, type(value))
t.fill(value)
return t
def make_tensor_at_offset(ptr: cute.Pointer, offset, shape, div_by=1):
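    # div_by encodes a divisibility hint: when div_by != 1 the offset is wrapped in
    # cute.assume(..., divby=div_by) (e.g. div_by=16 when the offset is known to be a
    # multiple of 16); div_by=1 adds no assumption.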
if div_by != 1:
offset = cute.assume(cutlass.as_numeric(offset), divby=div_by)
return cute.make_tensor(ptr + offset, shape)
def shuffle_elect(thread_extent):
# thread_extent is the number of threads of a warpgroup
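    # Returns a predicate that is True for every lane of the first warp in each
    # thread_extent-sized thread group (warp 0 only when thread_extent == 0).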
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
if thread_extent == 0:
return warp_idx == 0
else:
return (warp_idx % (thread_extent // 32)) == 0
def sync_thread_partial(barrier_id=None, thread_count=None):
bar_sync_ptx(barrier_id, thread_count)
# Packing functions
def pack_half2(x, y):
"""
Pack two half-precision (fp16) values into a single 32-bit value.
Corresponds to CUDA's __pack_half2 intrinsic.
This packs two fp16 values into a single int32 by treating the fp16 bits
as raw data and concatenating them.
"""
@dsl_user_op
def pack_half2_impl(x_val, y_val, *, loc=None, ip=None):
# Cast fp16 to uint16 (bitcast)
x_ir = x_val.ir_value(loc=loc, ip=ip) if hasattr(x_val, "ir_value") else x_val
y_ir = y_val.ir_value(loc=loc, ip=ip) if hasattr(y_val, "ir_value") else y_val
# Bitcast fp16 to i16
i16_type = mlir_ir.IntegerType.get_signless(16)
x_i16 = llvm.bitcast(i16_type, x_ir, loc=loc, ip=ip)
y_i16 = llvm.bitcast(i16_type, y_ir, loc=loc, ip=ip)
packed_xy = llvm.inline_asm(
Int32.mlir_type,
[x_i16, y_i16],
"mov.b32 $0, {$1, $2};",
"=r,h,h",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
return Int32(packed_xy)
return pack_half2_impl(x, y)
def AtomicAdd(ptr: cute.Pointer, value: Numeric, *, loc=None, ip=None):
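    # Relaxed, GPU-scope atomic add via nvvm.atomicrmw; only float32 and int32
    # pointers are supported, matching the dtype checks below.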
if ptr.dtype == cutlass.Float32:
ret = nvvm.atomicrmw(
T.f32(),
AtomicOpKind.FADD,
ptr.llvm_ptr,
ptr.dtype(value).ir_value(loc=loc, ip=ip),
mem_order=MemOrderKind.RELAXED,
syncscope=MemScopeKind.GPU,
loc=loc,
ip=ip,
)
elif ptr.dtype == cutlass.Int32:
ret = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.ADD,
ptr.llvm_ptr,
ptr.dtype(value).ir_value(loc=loc, ip=ip),
mem_order=MemOrderKind.RELAXED,
syncscope=MemScopeKind.GPU,
loc=loc,
ip=ip,
)
else:
raise ValueError(f"Unsupported dtype: {ptr.dtype}")
return ptr.dtype(ret)
from __future__ import annotations
from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op # noqa: F401
from cutlass._mlir.dialects import nvvm, cute_nvgpu # noqa: F401
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
import cutlass.cute as cute
from cutlass.cute.typing import Int, Boolean, Int32, Int16, Uint64, Union # noqa: F401
from cutlass.impl_utils import check_value_in
from cutlass.cute.arch import cp_async_commit_group as cp_async_commit # noqa: F401
from cutlass.cute.arch import cp_async_wait_group as cp_async_wait # noqa: F401
BYTES_PER_TENSORMAP = 128
BYTES_PER_POINTER = 8
def cp_async_gs(size, dst, dst_offset, src, src_offset):
assert size in [16, 8, 4]
    # Use .cg (cache-global) to bypass L1 for 16-byte copies; smaller copies use the default .ca path.
mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
if isinstance(src, cute.Tensor):
src_ptr = src.iterator
elif isinstance(src, cute.Pointer):
src_ptr = src
else:
raise ValueError(f"Invalid source type: {type(src)}")
if isinstance(dst, cute.Tensor):
dst_ptr = dst.iterator
elif isinstance(dst, cute.Pointer):
dst_ptr = dst
else:
raise ValueError(f"Invalid destination type: {type(dst)}")
cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)
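# Usage sketch (illustrative, placeholder names): stage a 16-byte chunk per thread
# from global to shared memory, then commit/wait on the async group:
#   cp_async_gs(16, smem_tensor, smem_off, gmem_tensor, gmem_off)
#   cp_async_commit()
#   cp_async_wait(0)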
@cute.jit
def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond):
if cond:
cp_async_gs(size, dst, dst_offset, src, src_offset)
@dsl_user_op
def extract_tensormap_ptr(tma_atom: cute.CopyAtom, *, loc=None, ip=None) -> cute.Pointer:
"""
Extract the tensormap pointer from a TMA Copy Atom.
:param tma_atom: The TMA Copy Atom
:type tma_atom: CopyAtom
"""
exec_value = _cute_nvgpu_ir.atom_make_exec_tma(tma_atom._trait.value, loc=loc, ip=ip)
ptr_type = _cute_ir.PtrType.get(Uint64.mlir_type, _cute_ir.AddressSpace.generic, 64)
tensormap_ptr = _cute_nvgpu_ir.get_tma_desc_addr(ptr_type, exec_value, loc=loc, ip=ip)
return tensormap_ptr
@dsl_user_op
def tma_load(tma_desc, mbar: cute.Pointer, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None:
"""
Load data from global memory to shared memory using TMA (Tensor Memory Access).
:param tma_desc: TMA descriptor for the tensor
:type tma_desc: CopyAtom or tensormap_ptr or Tensor of tensormap_ptr
:param mbar: Mbarrier pointer in shared memory
:type mbar: Pointer
:param smem_ptr: Destination pointer in shared memory
:type smem_ptr: Pointer
:param crd: Coordinates tuple for the tensor access
:type crd: tuple[Int, ...]
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
if not isinstance(crd, tuple) and isinstance(tma_desc, cute.Pointer):
# Legacy signature: tma_load(smem_ptr, gmem_ptr, mbar, size)
_smem_ptr = tma_desc
_gmem_ptr = mbar
_mbar = smem_ptr
nvvm.cp_async_bulk_shared_cluster_global(
dst_mem=_smem_ptr.llvm_ptr,
src_mem=_gmem_ptr.llvm_ptr,
mbar=_mbar.llvm_ptr,
size=Int32(crd).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
else:
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.cp_async_bulk_tensor_shared_cluster_global(
dst_mem=smem_ptr.llvm_ptr,
tma_descriptor=tma_desc_ptr.llvm_ptr,
coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd],
mbar=mbar.llvm_ptr,
im2col_offsets=[],
load_mode=nvvm.CpAsyncBulkTensorLoadMode.TILE,
group=nvvm.Tcgen05GroupKind.CTA_1,
use_intrinsic=False, # setting this to True leads to a compile error
loc=loc,
ip=ip,
)
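# Usage sketch (illustrative, placeholder names): issue a tiled TMA load into shared
# memory and signal completion through an mbarrier that the consumer waits on, e.g.
#   if shuffle_elect(128):
#       arrive_and_expect_tx(mbar_ptr, bytes_expected)   # mbar helper from this backend
#       tma_load(tma_atom, mbar_ptr, smem_ptr, (blk_m, k_iter))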
@dsl_user_op
def tma_store(tma_desc, smem_ptr: cute.Pointer, crd: Int | tuple[Int, ...], *, loc=None, ip=None) -> None:
"""
Store data from shared memory to global memory using TMA (Tensor Memory Access).
:param tma_desc: TMA descriptor for the tensor
:type tma_desc: TMA descriptor
:param smem_ptr: Source pointer in shared memory
:type smem_ptr: Pointer
:param crd: Coordinates tuple for the tensor access
:type crd: tuple[Int, ...]
"""
arch = CuTeDSL._get_dsl().envar.arch
check_value_in(arch, ["sm_90", "sm_90a", "sm_100a"], "arch")
if not isinstance(crd, tuple):
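# Legacy signature: tma_store(gmem_ptr, smem_ptr, size) - a 1-D bulk shared-to-global copy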
if arch not in ("sm_90", "sm_90a"):
raise NotImplementedError("tma_store(size) path is only implemented for sm_90/sm_90a")
gmem_ptr = tma_desc.align(smem_ptr.alignment)
_cute_nvgpu_ir.arch_copy_SM90_bulk_copy_s2g(
dsmem_data_addr=smem_ptr.value,
gmem_data_addr=gmem_ptr.value,
size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), crd),
loc=loc,
ip=ip,
)
else:
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.cp_async_bulk_tensor_global_shared_cta(
tma_descriptor=tma_desc_ptr.llvm_ptr,
src_mem=smem_ptr.llvm_ptr,
coordinates=[Int32(i).ir_value(loc=loc, ip=ip) for i in crd],
predicate=None,
loc=loc,
ip=ip,
)
@dsl_user_op
def tma_store_arrive(*, loc=None, ip=None) -> None:
"""
Indicate arrival of warp issuing TMA_STORE.
Corresponds to PTX instruction: cp.async.bulk.commit_group;
"""
nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip)
@dsl_user_op
def tma_store_wait(count: int, *, read=None, loc=None, ip=None) -> None:
"""
Wait for TMA_STORE operations to complete.
Corresponds to PTX instruction: cp.async.bulk.wait_group.read <count>;
:param count: The number of outstanding bulk async groups to wait for
:type count: Int
"""
nvvm.cp_async_bulk_wait_group(group=count, read=read, loc=loc, ip=ip)
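# Usage sketch (illustrative, placeholder names): a typical epilogue issues the store,
# commits the bulk-async group, then waits until at most N groups remain in flight:
#   tma_store(tma_atom, smem_ptr, (blk_m, blk_n))
#   tma_store_arrive()   # cp.async.bulk.commit_group
#   tma_store_wait(0)    # wait until all committed groups have completed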
@dsl_user_op
def cp_async_shared_global(
dst: cute.Pointer, src: cute.Pointer, cp_size: Int, modifier: nvvm.LoadCacheModifierKind, *, src_size: Int = None, loc=None, ip=None
) -> None:
"""
Asynchronously copy data from global memory to shared memory.
:param dst: Destination pointer in shared memory
:type dst: Pointer
:param src: Source pointer in global memory
:type src: Pointer
:param cp_size: Size of the copy in bytes
:type cp_size: Int
:param modifier: Cache modifier (e.g. CG or CA)
:type modifier: nvvm.LoadCacheModifierKind
:param src_size: Optional source-size override; defaults to cp_size when omitted
:type src_size: Int
"""
size = src_size if src_size is not None else cp_size
nvvm.cp_async_shared_global(
dst=dst.llvm_ptr,
src=src.llvm_ptr,
size=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), size),
modifier=modifier,
cp_size=Int32(cp_size).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
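# Note (illustrative): cp_async_gs above selects the CG modifier only for full 16-byte
# copies; a hypothetical direct call would look like
#   cp_async_shared_global(smem_ptr, gmem_ptr, 16, nvvm.LoadCacheModifierKind.CG)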
@dsl_user_op
def prefetch_tma_descriptor(tma_desc, *, loc=None, ip=None) -> None:
"""
Prefetch a TMA descriptor.
Corresponds to PTX instruction: prefetch.tensormap;
"""
if isinstance(tma_desc, cute.CopyAtom):
tma_desc_ptr = extract_tensormap_ptr(tma_desc)
elif isinstance(tma_desc, cute.Tensor):
tma_desc_ptr = tma_desc.iterator
else:
tma_desc_ptr = tma_desc
nvvm.prefetch_tensormap(tma_desc_ptr.llvm_ptr, loc=loc, ip=ip)
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils # noqa: F401
import math
import cutlass.utils.hopper_helpers as hopper_utils
from cutlass.utils import LayoutEnum
from cutlass.cute.nvgpu.warpgroup import OperandMajorMode, OperandSource, make_smem_layout_atom
def make_aligned_tensor(ptr: cute.Pointer, layout: cute.Layout, align_bytes: int, swizzle=False):
ptr = ptr.align(align_bytes)
if swizzle and isinstance(layout, cute.ComposedLayout):
ptr = cute.recast_ptr(ptr=ptr, swizzle_=layout.inner, dtype=ptr.dtype)
return cute.make_tensor(ptr, layout.outer)
return cute.make_tensor(ptr, layout)
def gemm_ss(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with both A and B from shared memory"""
if use_wgmma:
gemm = Gemm_SM90(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum)
else:
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm(A_ptr, B_ptr, C_ptr)
def gemm_rs(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with A from register/fragment and B from shared memory"""
if use_wgmma:
gemm = Gemm_SM90(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_rs(A_ptr, B_ptr, C_ptr, wg_wait=wg_wait, clear_accum=clear_accum)
else:
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_rs(A_ptr, B_ptr, C_ptr)
def gemm_sr(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with A from shared memory and B from register/fragment"""
# wgmma does not support the shared-register (SR) operand order, so always use the SM80 path
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
gemm.body_sr(A_ptr, B_ptr, C_ptr)
def gemm_rr(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
use_wgmma=None,
wg_wait=0,
A_ptr: cute.Pointer = None,
B_ptr: cute.Pointer = None,
C_ptr: cute.Pointer = None,
):
"""GEMM with both A and B from register/fragment"""
# Both operands in register, no copy needed
gemm = Gemm_SM80(
M,
N,
K,
warp_m,
warp_n,
trans_A,
trans_B,
clear_accum,
stride_A,
stride_B,
offset_A,
offset_B,
A_ptr.dtype,
B_ptr.dtype,
C_ptr.dtype,
)
# For gemm_rr, directly call _body_impl with copy_A=False, copy_B=False
gemm._body_impl(A_ptr, B_ptr, C_ptr, copy_A=False, copy_B=False)
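# Usage sketch (illustrative, placeholder names): the four entry points differ only in
# where A and B live (S = shared memory, R = registers). A Hopper shared-shared tile
# GEMM accumulating into registers might look like:
#   gemm_ss(M, N, K, warp_m, warp_n, 0, 1, False, stride_A, stride_B, 0, 0,
#           use_wgmma=True, wg_wait=0, A_ptr=sA_ptr, B_ptr=sB_ptr, C_ptr=acc_ptr)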
class Gemm_SM80:
_instances = {} # cache instances for the same arguments
def __new__(cls, *args):
key = args
if key not in cls._instances:
cls._instances[key] = super().__new__(cls)
return cls._instances[key]
# in Tilelang, trans_A == 0 or trans_B == 1 means K major
# in Cute, trans == 0 means K major
def __init__(
self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type
):
if not hasattr(self, "initialized"):
self.cta_tiler = (M, N, K)
self.mma_inst_shape = (16, 8, 16)
self.trans_A = trans_A != 0 # same as Tilelang
self.trans_B = trans_B == 0 # inverted relative to Tilelang
A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR
B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR
self.A_layout = self._make_smem_layout_AB(A_type, A_major_mode, 128, (M, K))
self.B_layout = self._make_smem_layout_AB(B_type, B_major_mode, 128, (N, K))
self.ab_dtype = A_type
self.acc_dtype = C_type
self.tiled_mma = self._make_tiled_mma(warp_m, warp_n)
self.clear_accum = clear_accum
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
is_row_major = major_mode == LayoutEnum.ROW_MAJOR
major_mode_size = smem_tiler[1] if is_row_major else smem_tiler[0]
major_mode_size = min(major_mode_size, 64)
swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
swizzle_bits = min(swizzle_bits, 3)
layout_atom_outer = (
cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
if is_row_major
else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
)
layout_atom = cute.make_composed_layout(
cute.make_swizzle(swizzle_bits, 3, 3),
0,
layout_atom_outer,
)
layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1) if is_row_major else (1, 0))
return layout
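# Worked example (illustrative): for fp16 (width=16), a 64-wide major mode and 128-bit
# copies, swizzle_bits = log2(64 * 16 // 128) = 3, i.e. make_swizzle(3, 3, 3), the
# familiar ldmatrix-friendly 128-byte swizzled layout.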
def _make_tiled_mma(self, warp_m, warp_n):
atom_layout_mnk = (warp_m, warp_n, 1)
op = cute.nvgpu.warp.MmaF16BF16Op(self.ab_dtype, self.acc_dtype, self.mma_inst_shape)
permutation_mnk = (
atom_layout_mnk[0] * self.mma_inst_shape[0],
atom_layout_mnk[1] * self.mma_inst_shape[1] * 2,
atom_layout_mnk[2] * self.mma_inst_shape[2],
)
tiled_mma = cute.make_tiled_mma(op, atom_layout_mnk, permutation_mnk)
return tiled_mma
@cute.jit
def __call__(
self,
sA_ptr: cute.Pointer,
sB_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
):
"""GEMM body: both A and B from shared memory"""
self._body_impl(sA_ptr, sB_ptr, rC_ptr, copy_A=True, copy_B=True)
@cute.jit
def body_rs(
self,
rA_ptr: cute.Pointer, # A already in register
sB_ptr: cute.Pointer, # B from shared memory
rC_ptr: cute.Pointer,
):
"""GEMM body_rs: A from register, B from shared memory"""
self._body_impl(rA_ptr, sB_ptr, rC_ptr, copy_A=False, copy_B=True)
@cute.jit
def body_sr(
self,
sA_ptr: cute.Pointer, # A from shared memory
rB_ptr: cute.Pointer, # B already in register
rC_ptr: cute.Pointer,
):
"""GEMM body_sr: A from shared memory, B from register"""
self._body_impl(sA_ptr, rB_ptr, rC_ptr, copy_A=True, copy_B=False)
@cute.jit
def _body_impl(
self,
A_ptr: cute.Pointer,
B_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
copy_A: cutlass.Constexpr = True,
copy_B: cutlass.Constexpr = True,
):
"""Internal implementation with configurable copy operations"""
tidx, _, _ = cute.arch.thread_idx()
thr_mma = self.tiled_mma.get_slice(tidx)
tCrA = None
tCrB = None
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
# Create copy operations only for operands that need copying
if cutlass.const_expr(copy_A):
sA = make_aligned_tensor(A_ptr, self.A_layout, 16)
tCsA = thr_mma.partition_A(sA)
tCrA = self.tiled_mma.make_fragment_A(tCsA)
atom_copy_s2r_A = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_A, 4),
sA.element_type,
)
tiled_copy_s2r_A = cute.make_tiled_copy(
atom_copy_s2r_A,
layout_tv=self.tiled_mma.tv_layout_A_tiled,
tiler_mn=(self.tiled_mma.get_tile_size(0), self.tiled_mma.get_tile_size(2)),
)
thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx)
tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA)
tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA)
else:
# A already in register
tCrA = cute.make_tensor(A_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2])))
if cutlass.const_expr(copy_B):
sB = make_aligned_tensor(B_ptr, self.B_layout, 16)
tCsB = thr_mma.partition_B(sB)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
atom_copy_s2r_B = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x8x16bOp(self.trans_B, 4),
sB.element_type,
)
tiled_copy_s2r_B = cute.make_tiled_copy(
atom_copy_s2r_B,
layout_tv=self.tiled_mma.tv_layout_B_tiled,
tiler_mn=(self.tiled_mma.get_tile_size(1), self.tiled_mma.get_tile_size(2)),
)
thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx)
tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB)
tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB)
else:
# B already in register
tCrB = cute.make_tensor(B_ptr, self.tiled_mma.partition_shape_B((self.cta_tiler[1], self.cta_tiler[2])))
if self.clear_accum:
tCrC.fill(0)
for k in cutlass.range(cute.size(tCrA, mode=[2])):
if cutlass.const_expr(copy_A):
cute.copy(tiled_copy_s2r_A, tCsA_copy_view[None, None, k], tCrA_copy_view[None, None, k])
if cutlass.const_expr(copy_B):
cute.copy(tiled_copy_s2r_B, tCsB_copy_view[None, None, k], tCrB_copy_view[None, None, k])
cute.gemm(self.tiled_mma, tCrC, tCrA[None, None, k], tCrB[None, None, k], tCrC)
class Gemm_SM90:
_instances = {} # cache instances for the same arguments
def __new__(cls, *args):
key = args
if key not in cls._instances:
cls._instances[key] = super().__new__(cls)
return cls._instances[key]
# in Tilelang, trans_A == 0 or trans_B == 1 means K major
# in Cute, trans == 0 means K major
def __init__(
self, M, N, K, warp_m, warp_n, trans_A, trans_B, clear_accum, stride_A, stride_B, offset_A, offset_B, A_type, B_type, C_type
):
if not hasattr(self, "initialized"):
self.cta_tiler = (M, N, K)
self.tiler_mn = (M, N)
self.atom_layout_mnk = (warp_m // 4, warp_n, 1)
self.trans_A = trans_A != 0 # same as Tilelang
self.trans_B = trans_B == 0 # inverted relative to Tilelang
self.a_leading_mode = OperandMajorMode.MN if self.trans_A else OperandMajorMode.K
self.b_leading_mode = OperandMajorMode.MN if self.trans_B else OperandMajorMode.K
A_major_mode = LayoutEnum.COL_MAJOR if self.trans_A else LayoutEnum.ROW_MAJOR
B_major_mode = LayoutEnum.COL_MAJOR if self.trans_B else LayoutEnum.ROW_MAJOR
self.A_layout = self.make_smem_layout_AB(A_type, A_major_mode, (M, K))
self.B_layout = self.make_smem_layout_AB(B_type, B_major_mode, (N, K))
self.a_dtype = A_type
self.b_dtype = B_type
self.acc_dtype = C_type
self.tiled_mma = None
self.A_source = None
self.clear_accum = clear_accum
@staticmethod
def make_tma_atom(
tensor,
smem_layout_staged,
smem_tile,
mcast_dim,
):
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp()
smem_layout = cute.slice_(smem_layout_staged, (None, None, 0))
tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
op,
tensor,
smem_layout,
smem_tile,
num_multicast=mcast_dim,
)
return tma_atom
@staticmethod
def get_tma_atom(tensor, tiler_mk, stages=1):
smem_layout_staged = Gemm_SM90.make_smem_layout_AB(tensor.element_type, LayoutEnum.from_tensor(tensor), tiler_mk, stages)
tma_atom = Gemm_SM90.make_tma_atom(tensor, smem_layout_staged, tiler_mk, 1)
return tma_atom
@staticmethod
def make_smem_layout_AB(dtype, major_mode: LayoutEnum, tiler_mk, stages=1):
smem_shape = tiler_mk
# Determine if K is the major mode and get the major mode size
is_k_major = major_mode.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K
major_mode_size = tiler_mk[1] if is_k_major else tiler_mk[0]
# Create SMEM layout atom for A tensor based on major mode and data type
smem_layout_atom = make_smem_layout_atom(
hopper_utils.get_smem_layout_atom(major_mode, dtype, major_mode_size),
dtype,
)
# Tile the SMEM layout atom to the A tensor shape and add staging dimension
smem_layout = cute.tile_to_shape(smem_layout_atom, cute.append(smem_shape, stages), order=(0, 1, 2) if is_k_major else (1, 0, 2))
return smem_layout
def _make_tiled_mma(self, is_rsMode=False):
tiled_mma = hopper_utils.make_trivial_tiled_mma(
self.a_dtype,
self.b_dtype,
self.a_leading_mode,
self.b_leading_mode,
self.acc_dtype,
self.atom_layout_mnk,
(64, self.tiler_mn[1] // self.atom_layout_mnk[1]),
OperandSource.SMEM if not is_rsMode else OperandSource.RMEM,
)
return tiled_mma
@cute.jit
def __call__(
self,
sA_ptr: cute.Pointer,
sB_ptr: cute.Pointer,
rC_ptr: cute.Pointer,
wg_wait: cutlass.Constexpr = 0,
clear_accum: cutlass.Constexpr = False,
):
tidx, _, _ = cute.arch.thread_idx()
self.tiled_mma = self._make_tiled_mma()
thr_mma = self.tiled_mma.get_slice(tidx)
sA_ptr = cute.recast_ptr(sA_ptr, self.A_layout.inner, dtype=sA_ptr.dtype)
sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype)
sA = cute.make_tensor(sA_ptr, self.A_layout.outer)
sB = cute.make_tensor(sB_ptr, self.B_layout.outer)
tCsA = thr_mma.partition_A(sA)
tCsB = thr_mma.partition_B(sB)
tCrA = self.tiled_mma.make_fragment_A(tCsA)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
cute.nvgpu.warpgroup.fence()
if cutlass.const_expr(clear_accum):
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
else:
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
num_k_blocks = cute.size(tCrA, mode=[2])
for k in cutlass.range(num_k_blocks):
tCrA_1phase = tCrA[None, None, k, 0]
tCrB_1phase = tCrB[None, None, k, 0]
cute.gemm(self.tiled_mma, tCrC, tCrA_1phase, tCrB_1phase, tCrC)
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.commit_group()
if cutlass.const_expr(wg_wait >= 0):
cute.nvgpu.warpgroup.wait_group(wg_wait)
@cute.jit
def body_rs(
self,
rA_ptr: cute.Pointer, # A already in register (Fragment)
sB_ptr: cute.Pointer, # B from shared memory
rC_ptr: cute.Pointer,
wg_wait: cutlass.Constexpr = 0,
clear_accum: cutlass.Constexpr = False,
):
"""
GEMM body_rs for SM90/Hopper: A from register, B from shared memory.
Based on cute::tl_wgmma::GemmTensorOp::body_rs from gemm_sm90.h
"""
tidx, _, _ = cute.arch.thread_idx()
self.tiled_mma = self._make_tiled_mma(is_rsMode=True)
# if self.A_source != OperandSource.RMEM or self.tiled_mma is None:
# self.tiled_mma = self._make_tiled_mma(is_rsMode = True)
# self.A_source = OperandSource.RMEM
# B from shared memory (with swizzle)
sB_ptr = cute.recast_ptr(sB_ptr, self.B_layout.inner, dtype=sB_ptr.dtype)
sB = cute.make_tensor(sB_ptr, self.B_layout.outer)
# Use the existing tiled_mma
thr_mma = self.tiled_mma.get_slice(tidx)
# Partition B from shared memory - standard path
tCsB = thr_mma.partition_B(sB)
tCrB = self.tiled_mma.make_fragment_B(tCsB)
# A already in register
# For body_rs, A is NOT partitioned through thr_mma (it's already partitioned)
# We create the tensor directly with the full shape
# This matches C++: make_tensor(make_rmem_ptr(pA), partition_shape_A(...))
tCrA = cute.make_tensor(rA_ptr, self.tiled_mma.partition_shape_A((self.cta_tiler[0], self.cta_tiler[2])))
# C accumulator
tCrC = cute.make_tensor(rC_ptr, self.tiled_mma.partition_shape_C((self.cta_tiler[0], self.cta_tiler[1])))
# Fence operands (prepare for wgmma)
cute.nvgpu.warpgroup.fence()
# Note: warpgroup_arrive() is called internally by wgmma
# Set accumulation mode
if cutlass.const_expr(clear_accum):
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False)
else:
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
# GEMM loop
num_k_blocks = cute.size(tCrB, mode=[2])
for k_block in cutlass.range(num_k_blocks):
# Match the indexing pattern from __call__
# If tCrB has 4 dimensions (with pipeline), use [None, None, k, 0]
# Otherwise use [None, None, k]
tCrB_k = tCrB[None, None, k_block, 0] if cute.rank(tCrB) >= 4 else tCrB[None, None, k_block]
tCrA_k = tCrA[None, None, k_block, 0] if cute.rank(tCrA) >= 4 else tCrA[None, None, k_block]
cute.gemm(self.tiled_mma, tCrC, tCrA_k, tCrB_k, tCrC)
self.tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True)
cute.nvgpu.warpgroup.commit_group()
if cutlass.const_expr(wg_wait >= 0):
cute.nvgpu.warpgroup.wait_group(wg_wait)
"""
LDMATRIX and STMATRIX operations for CuTeDSL backend.
Based on tl_templates/cuda/ldsm.h
These functions provide wrappers around PTX ldmatrix/stmatrix instructions
for loading/storing 8x8 matrix fragments between shared memory and registers.
"""
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass._mlir import ir # noqa: F401
from cutlass.cute.typing import Pointer, Int32 # noqa: F401
import cutlass.cute as cute
def _to_ir_value(v, loc=None, ip=None):
"""Convert value to MLIR IR, handling both cutlass types and raw MLIR Values"""
if hasattr(v, "ir_value"):
return v.ir_value(loc=loc, ip=ip)
else:
# Already an MLIR Value
return v
def _ldmatrix(smem_ptr, local_ptr, num, transpose, loc=None, ip=None):
"""Internal helper for ldmatrix operations"""
layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
assert num in [2, 4]
ret_type = llvm.StructType.get_literal([T.i32()] * num)
out_i32 = nvvm.ldmatrix(ret_type, smem_ptr.llvm_ptr, num=num, layout=layout, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), num)
for i in range(num):
out[i] = cute.Int32(llvm.extractvalue(T.i32(), out_i32, [i], loc=loc, ip=ip))
def _stmatrix(smem_ptr, values, transpose, loc=None, ip=None):
"""Internal helper for stmatrix operations"""
layout = nvvm.MMALayout.col if transpose else nvvm.MMALayout.row
ir_values = [_to_ir_value(v, loc, ip) for v in values]
nvvm.stmatrix(smem_ptr.llvm_ptr, ir_values, layout=layout, loc=loc, ip=ip)
# ============================================================================
# LDMATRIX operations (load from shared memory to registers)
# ============================================================================
@dsl_user_op
def ptx_ldmatrix_x1(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 1 matrix (8x8) from shared memory"""
# _ldmatrix(smem_ptr, local_ptr, 1, False, loc, ip)
out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.row, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
out[0] = cute.Int32(out_i32)
@dsl_user_op
def ptx_ldmatrix_x2(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 2 matrices (8x8 each) from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 2, False, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x4(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 4 matrices (8x8 each) from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 4, False, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x1_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 1 matrix (8x8) with transpose from shared memory"""
out_i32 = nvvm.ldmatrix(T.i32(), smem_ptr.llvm_ptr, num=1, layout=nvvm.MMALayout.col, loc=loc, ip=ip)
out = cute.make_tensor(cute.recast_ptr(local_ptr, dtype=cute.Int32), 1)
out[0] = cute.Int32(out_i32)
@dsl_user_op
def ptx_ldmatrix_x2_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 2 matrices (8x8 each) with transpose from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 2, True, loc, ip)
@dsl_user_op
def ptx_ldmatrix_x4_trans(smem_ptr: Pointer, local_ptr: Pointer, *, loc=None, ip=None) -> None:
"""Load 4 matrices (8x8 each) with transpose from shared memory"""
_ldmatrix(smem_ptr, local_ptr, 4, True, loc, ip)
# ============================================================================
# STMATRIX operations (store from registers to shared memory)
# ============================================================================
@dsl_user_op
def ptx_stmatrix_x1(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None:
"""Store 1 matrix (8x8) to shared memory"""
_stmatrix(smem_ptr, [value0], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x2(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None:
"""Store 2 matrices (8x8 each) to shared memory"""
_stmatrix(smem_ptr, [value0, value1], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x4(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None:
"""Store 4 matrices (8x8 each) to shared memory"""
_stmatrix(smem_ptr, [value0, value1, value2, value3], False, loc, ip)
@dsl_user_op
def ptx_stmatrix_x1_trans(smem_ptr: Pointer, value0, *, loc=None, ip=None) -> None:
"""Store 1 matrix (8x8) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0], True, loc, ip)
@dsl_user_op
def ptx_stmatrix_x2_trans(smem_ptr: Pointer, value0, value1, *, loc=None, ip=None) -> None:
"""Store 2 matrices (8x8 each) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0, value1], True, loc, ip)
@dsl_user_op
def ptx_stmatrix_x4_trans(smem_ptr: Pointer, value0, value1, value2, value3, *, loc=None, ip=None) -> None:
"""Store 4 matrices (8x8 each) with transpose to shared memory"""
_stmatrix(smem_ptr, [value0, value1, value2, value3], True, loc, ip)
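# Usage sketch (illustrative, placeholder names): load four 8x8 fragments for an MMA
# operand, then store two results back with the transposed variant:
#   ptx_ldmatrix_x4(smem_src_ptr, reg_dst_ptr)
#   ptx_stmatrix_x2_trans(smem_dst_ptr, r0, r1)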
import cutlass.cute as cute
from cutlass.cute.typing import Union, Numeric
from cutlass.cute.tensor import TensorSSA
from cutlass._mlir.dialects import arith
from cutlass.cute.math import exp, exp2, log, log2, log10, tan, cos, sin, sqrt # noqa: F401
def divf(x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], fastmath: bool = False) -> Union[TensorSSA, Numeric]:
return cute.math._math_op(arith.divf, fastmath, x, y)
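# Usage sketch (illustrative): divf mirrors the other elementwise math wrappers and
# follows the fast-math pass config, e.g.
#   z = divf(x, y, fastmath=True)  # approximate divide when fast math is enabled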
"""
Simple wrappers that delegate to cutlass.cute.arch implementations.
We use the existing implementations from cutlass rather than reinventing the wheel.
"""
from cutlass.cute.typing import Pointer, Int, Int32, Boolean # noqa: F401
from cutlass.cutlass_dsl import CuTeDSL, dsl_user_op # noqa: F401
from cutlass._mlir.dialects import nvvm
from cutlass.cute.arch import mbarrier_init, mbarrier_expect_tx, mbarrier_arrive # noqa: F401
from cutlass.cute.arch import mbarrier_arrive_and_expect_tx as arrive_and_expect_tx # noqa: F401
from cutlass.cute.arch import cp_async_mbarrier_arrive_noinc as mbarrier_cp_async_arrive_noinc # noqa: F401
import cutlass.cute.arch as arch
@dsl_user_op
def mbarrier_wait(mbar_ptr: Pointer, phase: Int, timeout_ns: Int = 10000000, *, loc=None, ip=None) -> None:
"""Waits on a mbarrier with a specified phase."""
nvvm.mbarrier_try_wait_parity_shared(
mbar_ptr.llvm_ptr,
Int32(phase).ir_value(loc=loc, ip=ip),
Int32(timeout_ns).ir_value(loc=loc, ip=ip),
loc=loc,
ip=ip,
)
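# Usage sketch (illustrative, placeholder names): the consumer side of a producer/
# consumer pipeline waits on the barrier phase (default timeout hint ~10 ms):
#   mbarrier_wait(mbar_ptr, phase)  # wraps nvvm.mbarrier_try_wait_parity_shared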
@dsl_user_op
def mbarrier_cp_async_arrive(mbar_ptr: Pointer, *, loc=None, ip=None) -> None:
mbar_llvm_ptr = mbar_ptr.llvm_ptr
nvvm.cp_async_mbarrier_arrive_shared(
mbar_llvm_ptr,
noinc=False,
loc=loc,
ip=ip,
)
def fence_proxy_async():
arch.fence_proxy(arch.ProxyKind.async_shared, space=arch.SharedSpace.shared_cta)
def fence_barrier_init():
arch.mbarrier_init_fence()