"tests/test_class_sh_inheritance.cpp" did not exist on "d7efc9b8b1c10c52097d140eba359dd617b1138b"
Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.cc
*/
#include "codegen_c_host.h"
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
// For escaping strings embedded into generated C sources
#include "support/str_escape.h"
namespace tvm {
namespace tl {
// Constructor: reserve a fresh, unique name for the library-context handle.
// The corresponding C definition is emitted later (see InitGlobalContext()
// and DefineModuleName()).
CodeGenCHost::CodeGenCHost() {
module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
}
/*!
 * \brief Reset the generator state and emit the preamble of the generated
 *        C translation unit (target banner, exports macro, includes).
 * \param output_ssa Whether the base CodeGenC should emit SSA form.
 * \param emit_asserts Whether AssertStmt nodes produce runtime checks.
 * \param emit_fwd_func_decl Whether forward declarations are emitted.
 * \param target_str Target string recorded as a comment in the output.
 * \param devices Device names from the module; currently not referenced here.
 */
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts,
                        bool emit_fwd_func_decl, std::string target_str,
                        const std::unordered_set<std::string> &devices) {
  declared_globals_.clear();
  emit_asserts_ = emit_asserts;
  emit_fwd_func_decl_ = emit_fwd_func_decl;
  // Preamble of the generated source, in a fixed order.
  decl_stream << "// tilelang target: " << target_str << "\n"
              << "#define TVM_EXPORTS\n"
              << "#include \"tvm/runtime/base.h\"\n"
              << "#include \"tvm/runtime/c_backend_api.h\"\n"
              << "#include \"tvm/ffi/c_api.h\"\n"
              << "#include <math.h>\n"
              // snprintf for richer assert messages with actual values
              << "#include <stdio.h>\n"
              << "#include <stdbool.h>\n";
  CodeGenCHost::InitGlobalContext();
  tvm::codegen::CodeGenC::Init(output_ssa);
}
// Emit the definition of the global library-context handle into the
// declaration section of the generated C source.
void CodeGenCHost::InitGlobalContext() {
decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx
<< " = NULL;\n";
}
// Emit the definition of the module handle variable whose name was chosen
// in the constructor.
void CodeGenCHost::DefineModuleName() {
decl_stream << "void* " << module_name_ << " = NULL;\n";
}
// Two-argument overload: forwards to the three-argument version without
// emitting forward declarations for called functions.
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &func) {
return AddFunction(gvar, func, /*emit_fwd_func_decl=*/false);
}
/*!
 * \brief Generate code for one PrimFunc and record its exported symbol.
 *
 * If the function is marked as the entry function and no entry has been
 * generated yet, an additional C wrapper with the canonical FFI entry name
 * is emitted that forwards to the function's global symbol.
 */
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
                               const tvm::tir::PrimFunc &func,
                               bool emit_fwd_func_decl) {
  auto global_symbol =
      func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
  if (global_symbol.has_value()) {
    function_names_.push_back(global_symbol.value());
  }
  emit_fwd_func_decl_ = emit_fwd_func_decl;
  tvm::codegen::CodeGenC::AddFunction(gvar, func);
  // Only the first entry function produces the auto-generated main wrapper.
  if (!func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc) || has_main_func_) {
    return;
  }
  ICHECK(global_symbol.has_value())
      << "CodeGenCHost: The entry func must have the global_symbol "
         "attribute, "
      << "but function " << gvar << " only has attributes " << func->attrs;
  function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main);
  stream << "// CodegenC: NOTE: Auto-generated entry function\n";
  PrintFuncPrefix(stream);
  PrintType(func->ret_type, stream);
  stream << " " << tvm::ffi::symbol::tvm_ffi_main
         << "(void* self, void* args,int num_args, void* result) {\n"
         << " return " << static_cast<std::string>(global_symbol.value())
         << "(self, args, num_args, result);\n"
         << "}\n";
  has_main_func_ = true;
}
/*!
 * \brief Emit a forward declaration for \p global_symbol into the forward
 *        declaration stream, unless declarations are disabled or the symbol
 *        already has a full definition in this module.
 */
void CodeGenCHost::GenerateForwardFunctionDeclarations(
    tvm::ffi::String global_symbol, const tvm::ffi::Array<tvm::Type> &arg_types,
    const tvm::Type &ret_type) {
  if (!emit_fwd_func_decl_) {
    return;
  }
  // A symbol that is defined in this module needs no forward declaration.
  for (const auto &defined_name : GetFunctionNames()) {
    if (global_symbol == defined_name) {
      return;
    }
  }
  PrintFuncPrefix(fwd_decl_stream);
  PrintType(ret_type, fwd_decl_stream);
  fwd_decl_stream << " " << global_symbol << "(";
  bool first_arg = true;
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (!first_arg) {
      fwd_decl_stream << ", ";
    }
    first_arg = false;
    tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
  }
  fwd_decl_stream << ");\n";
}
// Prefix every generated function with an extern "C" guard so that C++
// compilers keep the unmangled symbol name.
void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*)
  os << "#ifdef __cplusplus\n"
        "extern \"C\"\n"
        "#endif\n";
}
/*!
 * \brief Print the C spelling of a TVM data type.
 *
 * Handles: handles (void*), void, bool, float16/32/64, bfloat16, and
 * (u)int with 1/8/16/32/64 bits; short vectors (2..16 lanes) get a numeric
 * lane suffix. Anything else aborts with a fatal error.
 */
void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) { // NOLINT(*)
  const int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK_EQ(lanes, 1) << "does not support vector types";
    os << "void*";
    return;
  }
  if (t.is_void()) {
    os << "void";
    return;
  }
  if (t == tvm::DataType::Bool()) {
    os << "bool";
    return;
  }
  // Append the lane count for short vectors; false means the lane count is
  // unsupported and the caller should fall through to the fatal error.
  auto emit_lane_suffix = [lanes, &os]() -> bool {
    if (lanes == 1) {
      return true;
    }
    if (lanes >= 2 && lanes <= 16) {
      os << lanes;
      return true;
    }
    return false;
  };
  if (t.is_float()) {
    const char *base = nullptr;
    switch (t.bits()) {
    case 16:
      base = "half";
      break;
    case 32:
      base = "float";
      break;
    case 64:
      base = "double";
      break;
    default:
      break;
    }
    if (base != nullptr) {
      os << base;
      if (emit_lane_suffix()) {
        return;
      }
    }
  }
  if (t.is_bfloat16()) {
    os << "__bf16";
    return;
  }
  if (t.is_int() || t.is_uint()) {
    if (t.is_uint()) {
      os << 'u';  // "u" + "intN_t" spells the unsigned variant
    }
    const char *base = nullptr;
    switch (t.bits()) {
    case 8:
      base = "int8_t";
      break;
    case 16:
      base = "int16_t";
      break;
    case 1:  // 1-bit integers are widened to 32-bit in generated C
    case 32:
      base = "int32_t";
      break;
    case 64:
      base = "int64_t";
      break;
    default:
      break;
    }
    if (base != nullptr) {
      os << base;
      if (emit_lane_suffix()) {
        return;
      }
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
// Broadcast a scalar to a vector by emitting a C cast over the scalar value
// repeated once per lane: ((vectype)(v, v, ..., v)).
void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op,
                              std::ostream &os) { // NOLINT(*)
  // Evaluate the scalar first; PrintExpr may emit SSA statements.
  const std::string scalar = PrintExpr(op->value);
  const int lane_count = op->dtype.lanes();
  os << "((";
  PrintType(op->dtype, os);
  os << ")(";
  for (int lane = 0; lane < lane_count; ++lane) {
    if (lane != 0) {
      os << ", ";
    }
    os << scalar;
  }
  os << "))";
}
// Emit lazy-lookup code for a packed function: the generated C tests the
// cached handle and, while it is still NULL, resolves `func_name` through
// TVMBackendGetFuncFromEnv; a failed lookup makes the generated function
// return -1.
void CodeGenCHost::PrintGetFuncFromBackend(
const std::string &func_name, const std::string &packed_func_name) {
this->PrintIndent();
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \""
<< func_name << "\""
<< ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
// Propagate the lookup failure to the caller of the generated function.
this->stream << "return -1;\n";
this->EndScope(get_func_env_scope);
this->PrintIndent();
this->stream << "}\n";
this->EndScope(packed_func_if_scope);
this->PrintIndent();
this->stream << "}\n";
}
// Emit a call to a lowered packed function.
// op->args layout: [0] callee name (StringImm), [1] argument stack buffer,
// [2] begin index, [3] end index; num_args = end - begin.
// For tvm_call_packed_lowered the callee is resolved at run time through a
// cached static handle; for tvm_call_cpacked_lowered the symbol is invoked
// directly with a NULL handle.
void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) {
using namespace tvm::tir;
const StringImmNode *func_name = op->args[0].as<StringImmNode>();
ICHECK(func_name != nullptr)
<< "tvm_call_[c]packed_lowered expects first argument as function name";
int64_t begin = op->args[2].as<IntImmNode>()->value;
int64_t end = op->args[3].as<IntImmNode>()->value;
int64_t num_args = end - begin;
ICHECK_GE(num_args, 0);
std::string packed_func_name;
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
packed_func_name = GetPackedName(op);
this->PrintGetFuncFromBackend(func_name->value, packed_func_name);
} else {
// directly use the original symbol
ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
packed_func_name =
tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
}
std::string args_stack = PrintExpr(op->args[1]);
this->PrintIndent();
// Result slot for the callee, zero-initialized field by field.
std::string result = name_supply_->FreshName("result");
this->stream << "TVMFFIAny " << result << ";\n";
this->PrintIndent();
// must make sure type_index is set to none
this->stream << result << ".type_index = kTVMFFINone;\n";
this->PrintIndent();
this->stream << result << ".zero_padding = 0;\n";
this->PrintIndent();
this->stream << result << ".v_int64 = 0;\n";
this->PrintIndent();
if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", ";
} else {
this->stream << "if (" << packed_func_name << "(NULL, ";
}
this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", "
<< "&" << result << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
// A non-zero return from the callee aborts the generated function.
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
/*!
 * \brief Return the unique name of the static handle that caches the packed
 *        function called by \p op, declaring it on first use.
 */
std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) {
  using namespace tvm::tir;
  const StringImmNode *s = op->args[0].as<StringImmNode>();
  ICHECK(s != nullptr)
      << "tvm_call_packed_lowered expects first argument as function name";
  std::string callee = s->value;
  const std::string packed_func_name = callee + "_packed";
  // Reuse the handle if this packed function was already declared.
  auto it = declared_globals_.find(packed_func_name);
  if (it != declared_globals_.end()) {
    return it->second;
  }
  std::string unique_name = name_supply_->FreshName(packed_func_name);
  declared_globals_[packed_func_name] = unique_name;
  decl_stream << "static void* " << unique_name << " = NULL;\n";
  return unique_name;
}
// Host-side lowering of builtin call nodes:
//  - tvm_stack_alloca: declare a local TVMFFIAny array sized for the request
//  - tvm_call_[c]packed_lowered: emit a packed-function call
//  - tvm_throw_last_error: make the generated function return -1
// Anything else defers to the base CodeGenC visitor.
void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) { // NOLINT(*)
using namespace tvm::tir;
if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = name_supply_->FreshName("stack");
const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
// DLTensor entries may live in the same TVMFFIAny-typed buffer, so the
// alignment of TVMFFIAny must be a multiple of DLTensor's.
static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMFFIAny);
size_t size = 0;
// Round the requested byte count up to whole TVMFFIAny units.
if (type == "shape") {
size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit;
} else if (type == "tvm_ffi_any") {
size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
this->PrintCallPacked(op);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
tvm::codegen::CodeGenC::VisitExpr_(op, os);
}
}
// Lower AssertStmt to a guarded early return. When emit_asserts_ is set, the
// generated C checks the condition and, on failure, raises a RuntimeError
// through the FFI and returns -1. For an equality condition with a literal
// message, the expected/actual values are formatted into the message with
// snprintf. The asserted body is emitted in either case.
void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message: allow StringImm or general PrimExpr
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
bool msg_is_literal = (msg_node != nullptr);
std::string esc_msg;
std::string msg_expr;
if (msg_is_literal) {
// Escape the literal so it can be embedded in a C string constant.
const std::string &raw_msg = msg_node->value;
esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
} else {
// Non-literal message: evaluated as an expression at run time.
msg_expr = PrintExpr(op->message);
}
// Only print expected/got values for equality when message is StringImm
if (msg_is_literal) {
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
PrintIndent();
// Buffer is scoped to this if-block, so nested asserts only shadow it.
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent();
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
<< esc_msg << "\");\n";
}
} else {
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " << msg_expr
<< ");\n";
}
}
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
PrintIndent();
stream << "}\n";
}
this->PrintStmt(op->body);
}
// Lower min(a, b) to a ternary expression so the generated C does not rely
// on a library min implementation.
void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
// Lower max(a, b) to a ternary expression so the generated C does not rely
// on a library max implementation.
void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
/*!
 * \brief Emit a binary min/max-style operator as a C ternary expression.
 *
 * Both operands are forced into SSA form first so each is evaluated exactly
 * once despite appearing twice in the ternary.
 * \param op binary operator being expressed
 * \param compare string representation of the comparison operator
 * \param os stream reference to print into
 */
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare,
                                               std::ostream &os) { // NOLINT(*)
  std::ostringstream lhs_buf;
  VisitExpr(op->a, lhs_buf);
  const std::string lhs = SSAGetID(lhs_buf.str(), op->a.dtype());
  std::ostringstream rhs_buf;
  VisitExpr(op->b, rhs_buf);
  const std::string rhs = SSAGetID(rhs_buf.str(), op->b.dtype());
  os << "((" << lhs << ") " << compare << " (" << rhs << ") "
     << "? (" << lhs << ") : (" << rhs << "))";
}
} // namespace tl
} // namespace tvm
namespace tvm {
namespace tl {
using tvm::codegen::CodeGenSourceBase;
using tvm::codegen::CSourceModuleCreate;
using tvm::ffi::Array;
using tvm::ffi::Map;
using tvm::ffi::Module;
using tvm::ffi::String;
// Build function that mirrors TVM's host C codegen, registered under a
// TileLang-specific name.
// Build function that mirrors TVM's host C codegen, registered under a
// TileLang-specific name.
//
// \param mod IRModule containing only PrimFuncs.
// \param target Target whose string form and constants-byte-alignment
//               attribute configure the generator.
// \return A C-source module holding the generated code and function names.
::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod,
                                      ::tvm::Target target) {
  bool output_ssa = false;
  bool emit_asserts = true;
  bool emit_fwd_func_decl = true;
  // Collect device names referenced by the module. Fix: look the attribute
  // up once instead of querying GetAttr twice (test + value fetch).
  std::unordered_set<std::string> devices;
  if (auto device_contexts =
          mod->GetAttr<::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>>(
              "device_contexts")) {
    for (auto const &context : device_contexts.value()) {
      devices.insert(context.second.data());
    }
  }
  CodeGenCHost cg;
  cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
  cg.SetConstantsByteAlignment(
      target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16));
  // AOT runner functions sort after all other functions (see sort_key).
  auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool {
    return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false))
        .value();
  };
  std::vector<std::pair<::tvm::GlobalVar, ::tvm::tir::PrimFunc>> funcs;
  for (auto [gvar, base_func] : mod->functions) {
    ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>())
        << "CodegenCHost: Can only take PrimFunc";
    funcs.emplace_back(gvar, ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func));
  }
  // Sort for deterministic output: non-runner functions first, then by name.
  auto sort_key = [&is_aot_executor_fn](const auto &kv) {
    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
  };
  std::sort(funcs.begin(), funcs.end(),
            [&sort_key](const auto &kv_a, const auto &kv_b) {
              return sort_key(kv_a) < sort_key(kv_b);
            });
  // Declare every function before generating bodies so cross-references
  // between the generated functions resolve.
  for (const auto &[gvar, prim_func] : funcs) {
    cg.DeclareFunction(gvar, prim_func);
  }
  for (const auto &[gvar, prim_func] : funcs) {
    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
  }
  std::string code = cg.Finish();
  return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
// Register the build entry point under "target.build.tilelang_c" so the
// target system can dispatch to this codegen.
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost);
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_c_host.h
* \brief Generate C host code (TileLang copy).
*/
#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"
namespace tvm {
namespace tl {
// TileLang copy of TVM's CodeGenCHost, under the tl namespace.
// Inherits from tvm::codegen::CodeGenC.
class CodeGenCHost : public tvm::codegen::CodeGenC {
public:
CodeGenCHost();
/*! \brief Reset generator state and emit the generated file's preamble. */
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str,
const std::unordered_set<std::string> &devices);
/*! \brief Emit the definition of the global library-context handle. */
void InitGlobalContext();
void AddFunction(const tvm::GlobalVar &gvar,
const tvm::tir::PrimFunc &f) override;
/*! \brief Add a function, optionally emitting forward declarations for the
* functions it calls. */
void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f,
bool emit_fwd_func_decl);
/*!
* \brief Add functions from the (unordered) range to the current module in a
* deterministic order. This helps with debugging.
*
* \param functions A vector of unordered range of current module.
*/
void AddFunctionsOrdered(
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
/*! \brief Emit the definition of the module handle variable. */
void DefineModuleName();
using tvm::codegen::CodeGenC::PrintType;
void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const tvm::tir::BroadcastNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::CallNode *op,
std::ostream &os) override; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
void VisitExpr_(const tvm::tir::MinNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const tvm::tir::MaxNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*)
/*! \brief Emit a forward declaration for global_symbol unless it is already
* defined in this module or forward declarations are disabled. */
void GenerateForwardFunctionDeclarations(
tvm::ffi::String global_symbol,
const tvm::ffi::Array<tvm::Type> &arg_types,
const tvm::Type &ret_type) override;
/*! \brief Names of the emitted functions, plus the entry alias if any. */
tvm::ffi::Array<tvm::ffi::String> GetFunctionNames() {
return function_names_;
}
private:
/*! \brief name of the module context handle, chosen in the constructor */
std::string module_name_;
/*! \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/*! \brief names of the functions declared in this module */
tvm::ffi::Array<tvm::ffi::String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
* code */
bool emit_fwd_func_decl_;
/*! \brief whether the auto-generated entry function has been emitted */
bool has_main_func_ = false;
std::string GetPackedName(const tvm::tir::CallNode *op);
void PrintGetFuncFromBackend(const std::string &func_name,
const std::string &packed_func_name);
void PrintCallPacked(const tvm::tir::CallNode *op);
/*!
* \brief Print ternary conditional operator implementing binary `op`
* Forces the operands to be in SSA form.
* \param op binary operator being expressed
* \param compare string representation of comparison operator
* \param os stream reference to print into
*/
template <typename T>
inline void PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os); // NOLINT(*)
};
} // namespace tl
} // namespace tvm
#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
......@@ -29,6 +29,7 @@
#include <unordered_set>
#include <utility>
#include "../op/builtin.h"
#include "../support/ffi_aliases.h"
#include "support/str_escape.h"
#include "target/build_common.h"
......@@ -203,12 +204,12 @@ void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name,
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->stream << "TVMFFIAny " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
<< "(TVMValue*) stack_value"
<< "(TVMFFIAny*) stack_value"
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
......@@ -228,13 +229,13 @@ void CodeGenTileLangCPP::PrintFuncCallC(
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->stream << "TVMFFIAny " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (" << packed_func_name << "( "
<< "(TVMValue*) stack_value "
<< "(TVMFFIAny*) stack_value "
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
......@@ -260,6 +261,12 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
......@@ -294,7 +301,7 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
}
}
if (no_alias) {
if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
......
......@@ -107,7 +107,7 @@ struct CUDAIEEEMath {
}
};
static std::string GetFP8Type(DataType type) {
static std::string GetTileLangFP8Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
......@@ -131,16 +131,17 @@ static std::string GetFP8Type(DataType type) {
if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
type.is_float8_e4m3()) {
stream << "fp8_e4" << vec << "_t";
} else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
type.is_float8_e5m2()) {
} else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) {
stream << "fp8_e5" << vec << "_t";
} else if (type.is_float8_e8m0fnu()) {
stream << "fp8_e8" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
}
return stream.str();
}
std::string GetFP6Type(DataType type) {
std::string GetTileLangFP6Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
......@@ -171,32 +172,37 @@ std::string GetFP6Type(DataType type) {
return stream.str();
}
std::string GetFP4Type(DataType type) {
std::string GetTileLangFP4Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "x2";
vec = "_2";
} else if (lanes == 4) {
vec = "x4";
vec = "_4";
} else if (lanes == 8) {
vec = "x8";
vec = "_8";
} else if (lanes == 16) {
vec = "x16";
vec = "_16";
} else if (lanes == 32) {
vec = "_32";
} else if (lanes == 64) {
vec = "_64";
} else {
LOG(FATAL)
<< "Only support scalar and vector types of width (2, 4) for FP4";
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16, "
"32, 64) for FP4";
}
stream << "__nv_fp4";
std::string suffix;
if (type.code() == DataType::kFloat4_e2m1fn) {
suffix = "_e2m1";
suffix = "_e2";
} else {
LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
}
stream << vec << suffix;
stream << "fp4" << suffix << vec << "_t";
return stream.str();
}
......@@ -278,6 +284,9 @@ std::string CodeGenTileLangCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
if (enable_fp4_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp4.h>\n";
}
if (need_math_constants_h_) {
decl_stream << "#include <math_constants.h>\n";
......@@ -287,6 +296,10 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <cooperative_groups.h>\n";
}
if (need_curand_kernel_h_) {
decl_stream << "#include <curand_kernel.h>\n";
}
decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
if (enable_sparse_gemm_) {
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
......@@ -312,8 +325,13 @@ std::string CodeGenTileLangCUDA::Finish() {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
if (unroll_factor.count(op->loop_var.get())) {
stream << "#pragma unroll "
<< PrintExpr(unroll_factor[op->loop_var.get()]) << "\n";
} else {
stream << "#pragma unroll\n";
}
}
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent();
......@@ -432,18 +450,20 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
return;
} else if (t.is_float8()) {
enable_fp8_ = true;
os << GetFP8Type(t);
os << GetTileLangFP8Type(t);
return;
} else if (t.is_float6()) {
enable_fp6_ = true;
if (t.lanes() <= 4) {
os << GetFP6Type(t);
os << GetTileLangFP6Type(t);
}
return;
} else if (t.is_float4()) {
enable_fp4_ = true;
if (t.lanes() <= 4) {
os << GetFP4Type(t);
if (t.lanes() <= 64) {
os << GetTileLangFP4Type(t);
} else {
fail = true;
}
return;
} else if (t == DataType::Bool()) {
......@@ -660,7 +680,9 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < 256 / t.bits());
ICHECK(i >= 0 && i < 256 / t.bits())
<< "i: " << i << " t: " << t << " t.bits(): " << t.bits()
<< " t.lanes(): " << t.lanes();
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
......@@ -702,6 +724,22 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
os << "." << access[(i % 8) / 4];
// fp8_e5_4_t or fp8_e5_2_t
os << "." << access[i % 4];
} else if (t.is_float4_e2m1fn()) {
os << vec;
// fp4_e2_64_t
if (t.lanes() >= 64)
os << "." << access[i / 32];
// fp4_e2_32_t
if (t.lanes() >= 32)
os << "." << access[(i % 32) / 16];
// fp4_e2_16_t
if (t.lanes() >= 16)
os << "." << access[(i % 16) / 8];
// fp4_e2_8_t
if (t.lanes() >= 8)
os << "." << access[(i % 8) / 4];
// fp4_e2_4_t or fp4_e2_2_t
os << "." << access[i % 4];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
......@@ -805,6 +843,22 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2] << " = " << value << ";\n";
} else if (t.is_float4_e2m1fn()) {
stream << vec;
// fp4_e2_64_t
if (t.lanes() >= 64)
stream << "." << access[i / 32];
// fp4_e2_32_t
if (t.lanes() >= 32)
stream << "." << access[(i % 32) / 16];
// fp4_e2_16_t
if (t.lanes() >= 16)
stream << "." << access[(i % 16) / 8];
// fp4_e2_8_t
if (t.lanes() >= 8)
stream << "." << access[(i % 8) / 4];
// fp4_e2_4_t or fp4_e2_2_t
stream << "." << access[i % 4] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
......@@ -1073,8 +1127,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
// Handle conversion from float32 to float8 (E4M3/E5M2)
if (from_ty.is_float() &&
(target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
if (from_ty.is_float() && (target_ty.is_float8())) {
bool target_type_is_e4m3 = target_ty.is_float8_e4m3() ||
target_ty.is_float8_e4m3fn() ||
target_ty.is_float8_e4m3fnuz();
// FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
// (float2 -> fp8x2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
......@@ -1083,8 +1139,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
<< ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
<< src << ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
......@@ -1093,14 +1148,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
......@@ -1109,25 +1162,79 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+2), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+3), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
}
}
if (from_ty.is_float8() && target_ty.is_float()) {
bool from_type_is_e4m3 = from_ty.is_float8_e4m3() ||
from_ty.is_float8_e4m3fn() ||
from_ty.is_float8_e4m3fnuz();
// FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
// (fp8x2 -> float2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
// fp8x2 -> float2
PrintIndent();
stream << "*reinterpret_cast<float2*>(&(" << sret
<< ")) = "
"__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
"t*>(&("
<< src << ")), " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
// fp8x4 -> float4
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// fp8x8 -> float8
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+2) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[2], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+3) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[3], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
......@@ -1297,7 +1404,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
return os.str();
}
std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
......@@ -1645,10 +1752,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_grid())) {
this->need_cooperative_groups_ = true;
this->PrintIndent();
this->stream << "cooperative_groups::grid_group grid = "
"cooperative_groups::this_grid();\n";
this->PrintIndent();
this->stream << "grid.sync();\n";
this->stream << "cooperative_groups::this_grid().sync();\n";
} else if (op->op.same_as(tl::loop_break())) {
this->PrintIndent();
this->stream << "break;\n";
......@@ -2352,6 +2456,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
<< ")), \"r\"((int)" << guard << ")\n";
stream << ");\n";
} else if (op->op.same_as(tl::__ldg())) {
// Explicit read-only cached load. Preferred form: __ldg(BufferLoad(...)).
// Fallback form: __ldg(buffer, index)
const BufferLoadNode *bl = nullptr;
if (!op->args.empty()) {
bl = op->args[0].as<BufferLoadNode>();
}
if (bl == nullptr) {
LOG(FATAL) << "T.__ldg expects a BufferLoad as the first argument.";
}
const BufferNode *buffer = bl->buffer.get();
ICHECK_EQ(bl->indices.size(), 1)
<< "T.__ldg currently supports flattened 1D buffer accesses.";
PrimExpr base = bl->indices[0];
// Emit __ldg(&buffer_ref)
auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
os << "__ldg(&(" << buffer_ref << "))";
} else if (op->op.same_as(builtin::reinterpret())) {
DataType tgt_dtype = op->dtype;
DataType src_dtype = op->args[0]->dtype;
......@@ -2612,6 +2733,30 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
os << func_name << "(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(tl::rng_init())) {
this->need_curand_kernel_h_ = true;
this->curand_philox_state = name_supply_->FreshName("__philox_state");
this->PrintIndent();
this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state
<< ";\n";
this->PrintIndent();
this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
<< ", &" << this->curand_philox_state << ");\n";
// Store state_var for later use by rng_rand
} else if (op->op.same_as(tl::rng_rand())) {
this->need_curand_kernel_h_ = true;
os << "curand(&" << this->curand_philox_state << ")";
} else if (op->op.same_as(tl::warp_reduce_sum())) {
os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_max())) {
os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_min())) {
os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_bitand())) {
os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_bitor())) {
os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
} else {
CodeGenC::VisitExpr_(op, os);
}
......@@ -2654,7 +2799,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
} else if (op->attr_key == "pragma_unroll_factor") {
const IntImmNode *factor = op->value.as<IntImmNode>();
ICHECK(factor);
unroll_factor[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
}
CodeGenC::VisitStmt_(op);
}
......@@ -2798,7 +2948,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
} else {
bool can_vector_load = false;
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
// For sub-byte types with lanes > 1 in element_dtype, adjust the ramp
// pattern
int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
? value_dtype.lanes() / element_dtype.lanes()
: value_dtype.lanes();
if (arith::ramp(base, 1, ramp_lanes).Match(index)) {
const RampNode *ramp = index.as<RampNode>();
ICHECK(ramp);
can_vector_load = true;
......@@ -2810,11 +2965,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
// }
}
if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
// So we cannot vector load it.
can_vector_load = false;
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
......@@ -2848,6 +2998,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
}
}
// Lower a BufferStore to CUDA C. Three paths:
//   1. lane counts of value and buffer element match -> single assignment;
//   2. contiguous (stride-1 ramp) index -> vectorized store via PrintVecStore;
//   3. otherwise -> per-lane scalar stores through a casted element pointer.
// Only flat (1-D index) stores without a store predicate are supported.
void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) {
  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer store is not supported.";
  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index_expr = op->indices[0];
  Var buffer_var = op->buffer->data;
  if (value_dtype.lanes() == element_dtype.lanes()) {
    // Path 1: value and buffer element have the same vector width, so a
    // plain `ref = value;` assignment is type-correct as-is.
    std::string value = this->PrintExpr(op->value);
    std::string ref =
        this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
    this->PrintIndent();
    stream << ref << " = " << value << ";\n";
  } else {
    arith::PVar<PrimExpr> base;
    // For sub-byte types with lanes > 1 in element_dtype, adjust the ramp
    // pattern
    int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
                         ? value_dtype.lanes() / element_dtype.lanes()
                         : value_dtype.lanes();
    if (arith::ramp(base, 1, ramp_lanes).Match(index_expr)) {
      // Path 2: index is a stride-1 ramp, i.e. a contiguous slice, so the
      // whole vector can be stored in one shot.
      std::string value = this->PrintExpr(op->value);
      this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
    } else {
      // The assignment below introduces side-effect, and the resulting value
      // cannot be reused across multiple expression, thus a new scope is needed
      int vec_scope = BeginScope();
      // store elements separately
      // Bind index and value to SSA names once, then index lane-by-lane.
      std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype());
      std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
      std::string vid = GetVarID(buffer_var.get());
      for (int i = 0; i < value_dtype.lanes(); ++i) {
        this->PrintIndent();
        DataType elem_type = value_dtype.element_of();
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          // Declared buffer type differs from the stored element type:
          // cast the base pointer, preserving any storage-scope qualifier.
          stream << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, stream);
            }
          }
          PrintType(elem_type, stream);
          stream << "*)" << vid << ')';
        } else {
          stream << vid;
        }
        // Emit: buf[index.lane_i] = value.lane_i;
        stream << '[';
        PrintVecElemLoad(index, index_expr.dtype(), i, stream);
        stream << "] = ";
        PrintVecElemLoad(value, op->value.dtype(), i, stream);
        stream << ";\n";
      }
      EndScope(vec_scope);
    }
  }
}
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
......@@ -3212,6 +3425,20 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
CodeGenC::PrintType(func->ret_type, os);
CodeGenC::PrintExtraAttrs(func, os);
bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
func->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt =
func->GetAttr<ffi::Array<Integer>>("tl.readonly_param_indices")) {
for (const auto &idx : opt.value()) {
ro_param_indices.insert(static_cast<int>(Downcast<Integer>(idx)->value));
}
}
os << " " << function_name << "(";
for (size_t i = 0; i < func->params.size(); ++i) {
tir::Var v = func->params[i];
......@@ -3236,7 +3463,10 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
// If marked read-only, emit const qualifier before type.
if (ro_param_indices.count(static_cast<int>(i))) {
os << "const ";
}
CodeGenC::PrintType(GetType(v), os);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
......@@ -3244,7 +3474,7 @@ void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
}
}
if (no_alias) {
if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, os);
}
} else {
......@@ -3280,6 +3510,19 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
ICHECK(global_symbol)
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
// Read-only param indices attribute, if present.
std::unordered_set<int> ro_param_indices;
if (auto opt = f->GetAttr<ffi::Array<Integer>>("tl.readonly_param_indices")) {
for (const auto &idx : opt.value()) {
ro_param_indices.insert(static_cast<int>(Downcast<Integer>(idx)->value));
}
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
......@@ -3307,7 +3550,10 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
// If marked read-only, emit const qualifier before type.
if (ro_param_indices.count(static_cast<int>(i))) {
stream << "const ";
}
CodeGenC::PrintType(GetType(v), stream);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
......@@ -3315,7 +3561,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
}
}
if (no_alias) {
if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
......
......@@ -57,6 +57,7 @@ public:
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;
void VisitStmt_(const BufferStoreNode *op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
......@@ -87,6 +88,8 @@ private:
std::string vid_global_barrier_state_;
// Global barrier expected node.
std::string vid_global_barrier_expect_;
// Global curand state
std::string curand_philox_state;
// whether enable fp16
bool enable_fp16_{false};
......@@ -122,6 +125,8 @@ private:
bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h
bool need_cooperative_groups_{false};
// whether need curand_kernel.h
bool need_curand_kernel_h_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ =
Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
......@@ -140,6 +145,7 @@ private:
std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts;
std::unordered_map<const VarNode *, IntImm> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t,
......
/*!
* \file target/codegen_cutedsl.cc
*/
#include "codegen_cutedsl.h"
#include "codegen_utils.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <cmath>
#include <string>
#include <utility>
#include <vector>
#include "../op/builtin.h"
#include "arith/pattern_match.h"
namespace tvm {
namespace codegen {
namespace {
// The threshold of the loop extent to use cutlass.range_constexpr
// Higher values would lead to DSLOptimizationWarning:
// This static loop has 128 iterations, which may be very slow to compile,
// consider using `cutlass.range(..., unroll_full=True)` instead.
const int64_t LOOP_UNROLL_THRESHOLD = 64;
// Replace every occurrence of `from` inside `str` with `to`, in place.
// The search resumes just past each inserted replacement, so matches of
// `from` that appear within `to` are never re-scanned (no infinite loop).
void ReplaceAll(std::string &str, const std::string &from,
                const std::string &to) {
  ICHECK(!from.empty()) << "ReplaceAll(): `from` must be non-empty";
  for (auto cursor = str.find(from); cursor != std::string::npos;
       cursor = str.find(from, cursor + to.size())) {
    str.replace(cursor, from.size(), to);
  }
}
} // namespace
// Constructor: snapshot the fast-math setting from the active PassContext
// so later expression lowering can branch on it without re-querying.
CodeGenTileLangCuTeDSL::CodeGenTileLangCuTeDSL() {
  // `tl.enable_fast_math` defaults to false when the config key is absent.
  const auto ctx = tvm::transform::PassContext::Current();
  enable_fastmath_ =
      ctx->GetConfig<Bool>(tl::kEnableFastMath, Bool(false)).value();
}
/*!
 * \brief Map a CUDA-style math intrinsic name to its TileLang fast-math
 *        equivalent.
 *
 * Double-precision names (e.g. "exp") and their single-precision
 * "f"-suffixed variants (e.g. "expf") canonicalize to the same "tl.*"
 * helper; precision is carried by the operand types, not the name.
 * Fix: the table previously listed the "f" variants for exp/exp2/log/
 * log2/sqrt but omitted tanf/cosf/sinf/log10f — those are added for
 * consistency.
 *
 * \param func_name Intrinsic name as it appears in the IR.
 * \return The canonical "tl.*" name, or an empty string when `func_name`
 *         has no fast-math counterpart (callers fall back to the regular
 *         lowering path).
 */
std::string CodeGenTileLangCuTeDSL::CanonicalizeFastmathFunctionName_(
    const std::string &func_name) const {
  static const std::unordered_map<std::string, std::string> kFastMathMap = {
      {"divf", "tl.divf"},   {"exp", "tl.exp"},      {"expf", "tl.exp"},
      {"exp2", "tl.exp2"},   {"exp2f", "tl.exp2"},   {"log", "tl.log"},
      {"logf", "tl.log"},    {"log2", "tl.log2"},    {"log2f", "tl.log2"},
      {"log10", "tl.log10"}, {"log10f", "tl.log10"}, {"tan", "tl.tan"},
      {"tanf", "tl.tan"},    {"cos", "tl.cos"},      {"cosf", "tl.cos"},
      {"sin", "tl.sin"},     {"sinf", "tl.sin"},     {"sqrt", "tl.sqrt"},
      {"sqrtf", "tl.sqrt"},
  };
  auto it = kFastMathMap.find(func_name);
  if (it != kFastMathMap.end()) {
    return it->second;
  }
  // Unrecognized name: signal "not a fast-math intrinsic" to the caller.
  return "";
}
// Emit the CuTeDSL decorator that marks the generated Python function as a
// device kernel.
void CodeGenTileLangCuTeDSL::PrintFuncDecorator_(
    std::ostream &os) { // NOLINT(*)
  os << "@cute.kernel" << '\n';
}
void CodeGenTileLangCuTeDSL::PreFunctionBody_(const PrimFunc &f) {
PrintIndent();
stream << "threadIdx = tl.ThreadIdx()" << "\n";
PrintIndent();
stream << "blockIdx = tl.BlockIdx()" << "\n";
}
namespace {
// Spell a scalar TVM DataType as its CuTeDSL ("cutlass.*") type name.
// Fatal for vector types and for dtypes with no CuTeDSL counterpart.
// NOTE(review): the branch order matters — bfloat16/float8/float6/float4
// are assumed to use their own type codes rather than satisfying
// is_float(); verify if dtype codes change upstream.
std::string DTypeToString(DataType t) {
  ICHECK(t.is_scalar()) << "unsupported type " << t;
  if (t.is_void()) {
    return "void";
  }
  // TMA descriptors have a dedicated opaque C type, not a cutlass.* one.
  if (t == tl::cuTensorMapType()) {
    return "CUtensorMap";
  }
  int bits = t.bits();
  // Accumulate the cutlass suffix; left empty when the dtype is unsupported.
  std::string elem_type;
  if (t.is_float()) {
    // Standard IEEE widths only; other float widths fall through to FATAL.
    if (bits == 16 || bits == 32 || bits == 64) {
      elem_type = "Float" + std::to_string(bits);
    }
  } else if (t.is_bfloat16()) {
    elem_type = "BFloat16";
  } else if (t.is_float8()) {
    // Enumerate the fp8 encodings explicitly; several have no CuTeDSL
    // equivalent and are intentionally left unsupported.
    if (t.is_float8_e3m4()) {
      // unsupported
    } else if (t.is_float8_e4m3()) {
      elem_type =
          "Float8E4M3FN"; // Only Float8E4M3FN is supported at the moment
    } else if (t.is_float8_e4m3b11fnuz()) {
      // unsupported
    } else if (t.is_float8_e4m3fn()) {
      elem_type = "Float8E4M3FN";
    } else if (t.is_float8_e4m3fnuz()) {
      // unsupported
    } else if (t.is_float8_e5m2()) {
      elem_type = "Float8E5M2";
    } else if (t.is_float8_e5m2fnuz()) {
      // unsupported
    } else if (t.is_float8_e8m0fnu()) {
      elem_type = "Float8E8M0FNU";
    }
  } else if (t.is_float6()) {
    if (t.is_float6_e3m2fn()) {
      elem_type = "Float6E3M2FN";
    } else if (t.is_float6_e2m3fn()) {
      elem_type = "Float6E2M3FN";
    }
  } else if (t.is_float4()) {
    if (t.is_float4_e2m1fn()) {
      elem_type = "Float4E2M1FN";
    }
  } else if (t.is_bool()) {
    elem_type = "Boolean";
  } else if (t.is_uint()) {
    if (bits == 8 || bits == 16 || bits == 32 || bits == 64 || bits == 128) {
      elem_type = "Uint" + std::to_string(bits);
    }
  } else if (t.is_int()) {
    // Int4 is allowed in addition to the power-of-two widths.
    if (bits == 4 || bits == 8 || bits == 16 || bits == 32 || bits == 64 ||
        bits == 128) {
      elem_type = "Int" + std::to_string(bits);
    }
  }
  if (elem_type.empty()) {
    LOG(FATAL) << "Cannot convert type " << t << " to CuTeDSL type!";
  }
  return "cutlass." + elem_type;
}
} // namespace
// Print the CuTeDSL spelling of a scalar dtype (e.g. "cutlass.Float32").
// Vector dtypes must be decomposed by the caller before reaching here.
void CodeGenTileLangCuTeDSL::PrintType(DataType t,
                                       std::ostream &os) { // NOLINT(*)
  CHECK(t.is_scalar()) << "Should not print a non-scalar type in CuTeDSL: "
                       << t;
  const std::string spelling = DTypeToString(t);
  os << spelling;
}
// Lower a Broadcast to a filled tensor of `lanes` copies of the scalar,
// immediately loaded back as a vector value.
void CodeGenTileLangCuTeDSL::VisitExpr_(const BroadcastNode *op,
                                        std::ostream &os) { // NOLINT(*)
  // Print lanes first, then the value, matching the emission order any
  // side effects of PrintExpr_ would have occurred in.
  const std::string lanes_str = PrintExpr_(op->lanes);
  const std::string value_str = PrintExpr_(op->value);
  os << "tl.make_filled_tensor((" << lanes_str << ",), " << value_str
     << ").load()";
}
// Lower a float immediate to a typed CuTeDSL literal, e.g.
//   cutlass.Float32(float.fromhex('0x1.8p+1'))
// Infinities and NaN use Python's float('inf') / float('nan') spellings;
// finite values use hexfloat so no precision is lost in the round-trip.
void CodeGenTileLangCuTeDSL::VisitExpr_(const FloatImmNode *op,
                                        std::ostream &os) { // NOLINT(*)
  const int bits = op->dtype.bits();
  if (bits != 64 && bits != 32 && bits != 16 && bits != 8 && bits != 4) {
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
  std::ostringstream literal;
  // The literal always starts with the target type as a constructor call.
  PrintType(op->dtype, literal);
  if (std::isinf(op->value)) {
    // For CuTeDSL, use Python's float('inf') instead of CUDA macros.
    literal << "(" << (op->value < 0 ? "float('-inf')" : "float('inf')")
            << ")";
  } else if (std::isnan(op->value)) {
    literal << "(float('nan'))";
  } else {
    // Hexfloat (C format) is accepted verbatim by Python's float.fromhex.
    literal << "(float.fromhex('" << std::hexfloat << op->value << "'))";
  }
  MarkConst(literal.str());
  os << literal.str();
}
void CodeGenTileLangCuTeDSL::VisitExpr_(const CastNode *op,
std::ostream &os) { // NOLINT(*)
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
if (from_ty.is_scalar())
return CodeGenTileLangPY::VisitExpr_(op, os);
// Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
PrintIndent();
stream << sret << " = tl.make_rmem_tensor((" << target_ty.lanes() << ",), ";
PrintType(target_ty.element_of(), stream);
stream << ")\n";
std::string src = SSAGetID(PrintExpr_(op->value), from_ty);
PrintIndent();
stream << sret << ".store(" << src << ".to(";
PrintType(target_ty.element_of(), stream);
stream << "))\n";
os << sret << ".load()";
return;
}
// Lower division: integral dtypes use Python floor division `//`; floating
// dtypes call tl.divf, with fastmath=True when enabled in the pass config.
void CodeGenTileLangCuTeDSL::VisitExpr_(const DivNode *op,
                                        std::ostream &os) { // NOLINT(*)
  const bool integral = op->dtype.is_int() || op->dtype.is_uint();
  if (integral) {
    PrintBinaryExpr_("//", op->dtype, op->a, op->b, os);
    return;
  }
  if (!enable_fastmath_) {
    PrintBinaryExpr_("tl.divf", op->dtype, op->a, op->b, os);
    return;
  }
  // Fast-math float division: spell the call directly to pass the keyword.
  const std::string lhs = PrintExpr_(op->a);
  const std::string rhs = PrintExpr_(op->b);
  os << "tl.divf(" << lhs << ", " << rhs << ", fastmath=True)";
}
// Lower Min to the tl.min runtime helper.
void CodeGenTileLangCuTeDSL::VisitExpr_(const MinNode *op,
                                        std::ostream &os) { // NOLINT(*)
  static constexpr const char *kHelper = "tl.min";
  PrintBinaryExpr_(kHelper, op->dtype, op->a, op->b, os);
}
// Lower Max to the tl.max runtime helper.
void CodeGenTileLangCuTeDSL::VisitExpr_(const MaxNode *op,
                                        std::ostream &os) { // NOLINT(*)
  static constexpr const char *kHelper = "tl.max";
  PrintBinaryExpr_(kHelper, op->dtype, op->a, op->b, os);
}
/**
* @brief Emit CuTeDSL-specific code for a call expression.
*
* This visitor handles CallNode intrinsics and builtins that require emitting
* CuTeDSL-specific code (inline PTX/ASM sequences, TensorLanguage runtime
* calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based
* stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The
* function writes the generated code to the provided output stream and falls
* back to the Python codegen for unrecognized calls.
*
* The method recognizes and emits code for (non-exhaustive): cp.async and its
* commit/wait variants, tma_load/store and im2col variants, ptX
* ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy
* MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX
* asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret
* paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm
* and related external calls, and other TL runtime calls.
*
* Side effects:
* - Emits to `os` and the internal codegen output stream.
* - May set internal feature flags (e.g., need_cooperative_groups_).
* - May open/close SSA scopes and mutate internal variable mappings.
* - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument
* patterns.
*
* @param op The call node to generate code for; the function inspects op->op
* and op->args to determine the appropriate emission.
* @param os Output stream to receive expression-level output when the caller
* expects an expression result (some paths write directly to the
* member stream instead).
*/
void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
size_t end = 0) {
// Cache context into a private ss, otherwise the let node may generate
// within the function call arguments.
std::ostringstream ss;
for (size_t i = start; i < op->args.size() - end; i++) {
if (i > start)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
PrintIndent();
stream << name << "(";
stream << ss.str();
stream << ")\n";
};
auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
std::ostringstream ss;
if (barrier_id.as<IntImmNode>()) {
// incase the barrier_id is an integer, we need to print the barrier_id as
// an integer
ss << "(" << mbarrier_name_ << "+" << barrier_id << ")";
} else {
// otherwise may be a T.get_mbarrier() call or BufferLoad Node
// we need to print the barrier_id as a string
ss << PrintExpr_(barrier_id);
}
return ss.str();
};
if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = PrintExpr_(op->args[0]);
std::string dst_offset = PrintExpr_(op->args[1]);
std::string src = PrintExpr_(op->args[2]);
std::string src_offset = PrintExpr_(op->args[3]);
std::string size = PrintExpr_(op->args[4]);
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
PrintIndent();
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset
<< ", " << src << ", " << src_offset << ")\n";
} else {
std::string condition = PrintExpr_(op->args[5]);
PrintIndent();
stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
<< dst_offset << ", " << src << ", " << src_offset << ", "
<< condition << ")\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl.cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
print_extern_call_stmt("tl.cp_async_wait");
} else if (op->op.same_as(builtin::create_barriers())) {
PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
stream << mbarrier_name_
<< " = tl.alloc_smem(cutlass.Uint64, size_in_elems=" << barrier_count
<< ")\n";
} else if (op->op.same_as(tl::get_mbarrier())) {
ICHECK_EQ(op->args.size(), 1);
std::string barrier_id = PrintExpr_(op->args[0]);
os << "(" << mbarrier_name_ << "+" << barrier_id << ")";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (op->args.size() == 1) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ")\n";
} else if (op->args.size() == 3) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto cta_id = PrintExpr_(op->args[1]);
auto pred = PrintExpr_(op->args[2]);
stream << "tl.mbarrier_arrive(" << mbarrier_obj << ", " << cta_id << ", "
<< pred << ")\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto arrive_count = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_init(" << mbarrier_obj << ", " << arrive_count
<< ")\n";
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
if (op->args.size() == 2) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->args.size() == 4) {
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
auto cta_id = PrintExpr_(op->args[2]);
auto pred = PrintExpr_(op->args[3]);
stream << "tl.arrive_and_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ", " << cta_id << ", " << pred << ")\n";
} else {
LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx "
<< op->args.size();
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl.mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::ptx_fence_barrier_init())) {
print_extern_call_stmt("tl.fence_barrier_init");
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
print_extern_call_stmt("tl.mbarrier_cp_async_arrive_noinc");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto transaction_bytes = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_expect_tx(" << mbarrier_obj << ", "
<< transaction_bytes << ")\n";
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
ICHECK_EQ(op->args.size(), 2);
PrintIndent();
auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
auto phase = PrintExpr_(op->args[1]);
stream << "tl.mbarrier_wait(" << mbarrier_obj << ", " << phase << ")\n";
} else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::no_set_max_nreg())) {
// do nothing
} else if (op->op.same_as(tl::tma_load())) {
std::ostringstream ss;
ICHECK_GE(op->args.size(), 2);
auto pol = op->args[op->args.size() - 1].as<IntImmNode>();
ICHECK(pol) << "Eviction policy must be IntImm";
ICHECK_GE(pol->value, 0);
ICHECK_LT(static_cast<size_t>(pol->value), eviction_policy_names_.size());
auto eviction_policy = eviction_policy_names_[pol->value];
// Simplify the code by using the default eviction policy
if (eviction_policy != "EVICT_NORMAL") {
LOG(FATAL) << "Eviction policy " << eviction_policy
<< " is not supported currently";
} else {
ss << "tl.tma_load(";
}
auto desc = op->args[0];
ss << PrintExpr_(desc) << ", ";
ss << print_mbarrier_obj(op->args[1]) << ", ";
ss << PrintExpr_(op->args[2]) << ", (";
for (size_t i = 3; i < op->args.size() - 1; i++) {
if (i > 3)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
ss << "))\n";
PrintIndent();
stream << ss.str();
} else if (op->op.same_as(tl::tma_load_im2col())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tma_store())) {
std::stringstream ss;
// Check minimum argument count (desc, data, at least one coord,
// need_reduce, eviction)
ICHECK_GE(op->args.size(), 4) << "tma_store requires at least 4 arguments "
"(desc, data, coords..., need_reduce, "
"eviction_policy), got "
<< op->args.size();
// Safely extract need_reduce flag
auto need_reduce_ptr = op->args[op->args.size() - 2].as<IntImmNode>();
ICHECK(need_reduce_ptr)
<< "tma_store need_reduce flag (args[-2]) must be IntImm, got "
<< op->args[op->args.size() - 2]->GetTypeKey();
auto need_reduce = need_reduce_ptr->value;
if (need_reduce) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
}
// Safely extract and validate eviction policy index
auto eviction_idx_ptr = op->args[op->args.size() - 1].as<IntImmNode>();
ICHECK(eviction_idx_ptr)
<< "tma_store eviction policy (args[-1]) must be IntImm, got "
<< op->args[op->args.size() - 1]->GetTypeKey();
ICHECK_GE(eviction_idx_ptr->value, 0)
<< "tma_store eviction policy index must be >= 0, got "
<< eviction_idx_ptr->value;
ICHECK_LT(static_cast<size_t>(eviction_idx_ptr->value),
eviction_policy_names_.size())
<< "tma_store eviction policy index " << eviction_idx_ptr->value
<< " out of bounds (max " << eviction_policy_names_.size() - 1 << ")";
auto eviction_policy = eviction_policy_names_[eviction_idx_ptr->value];
ss << "tl.tma_store(";
auto desc = op->args[0];
ss << PrintExpr_(desc) << ", ";
ss << PrintExpr_(op->args[1]) << ", (";
for (size_t i = 2; i < op->args.size() - 2; i++) {
if (i > 2)
ss << ", ";
ss << PrintExpr_(op->args[i]);
}
ss << ")";
if (eviction_policy != "EVICT_NORMAL") {
ss << ", eviction_kind = nvvm.EvictKind." << eviction_policy.substr(6);
}
ss << ")\n";
PrintIndent();
stream << ss.str();
} else if (op->op.same_as(tl::ptx_ldmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl.ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::ptx_stmatrix())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl.ptx_stmatrix_x" + std::to_string(num);
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::fence_proxy_async())) {
print_extern_call_stmt("tl.fence_proxy_async");
} else if (op->op.same_as(tl::tma_store_arrive())) {
print_extern_call_stmt("tl.tma_store_arrive");
} else if (op->op.same_as(tl::tma_store_wait())) {
PrintIndent();
stream << "tl.tma_store_wait(0)\n";
} else if (op->op.same_as(tl::warpgroup_arrive())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_commit_batch())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_wait())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::set_max_nreg())) {
PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name =
is_inc ? "tl.warpgroup_reg_alloc" : "tl.warpgroup_reg_dealloc";
stream << func_name << "(" << nreg << ")\n";
} else if (op->op.same_as(tl::wait_wgmma())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::pack_b16())) {
os << "tl.pack_half2(" << PrintExpr_(op->args[0]) << ", "
<< PrintExpr_(op->args[1]) << ")";
} else if (op->op.same_as(tl::sync_grid())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::loop_break())) {
PrintIndent();
stream << "break\n";
} else if (op->op.same_as(builtin::ptx_mma())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_mma_sm70())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_wgmma_ss())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::mma_store())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::mma_fill())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::ptx_ldg32())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::reinterpret())) {
DataType tgt_dtype = op->dtype;
DataType src_dtype = op->args[0]->dtype;
ICHECK_EQ(tgt_dtype.lanes() * tgt_dtype.bits(),
src_dtype.lanes() * src_dtype.bits())
<< "reinterpret expects source and target to have the same number of "
"bits";
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1)
<< "CodeGenTileLangCuTeDSL only supports flat memory";
PrimExpr index = load->indices[0];
if (const RampNode *node = index.as<RampNode>(); node) {
auto *p_stride = as_const_int(node->stride);
CHECK(p_stride);
ICHECK_EQ(*p_stride, 1) << "reinterpret expects contiguous elements";
index = node->base;
}
auto ptr_str = GetBufferPtr_(load->buffer.get(), index);
os << "tl.make_tensor(tl.recast_ptr(" << ptr_str << ", dtype=";
PrintType(tgt_dtype.element_of(), os);
os << "), (" << tgt_dtype.lanes() << ",)).load()";
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
PrintCallExtern_(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
op_instance->value, op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_lane_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::get_warp_group_idx())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl.shuffle_elect(" << PrintExpr_(op->args[0]) << ")";
} else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::increase_descriptor_offset())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::__exp())) {
os << "tl.exp2(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__exp10())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::__log())) {
os << "tl.log(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__log2())) {
os << "tl.log2(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__log10())) {
os << "tl.log10(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__tan())) {
os << "tl.tan(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__cos())) {
os << "tl.cos(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::__sin())) {
os << "tl.sin(" << PrintExpr_(op->args[0]) << ", fastmath=True)";
} else if (op->op.same_as(tl::ieee_add())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_sub())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_mul())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fmaf())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_frcp())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fsqrt())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_frsqrt())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::ieee_fdiv())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_sum())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_max())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_min())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_bitand())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(tl::warp_reduce_bitor())) {
LOG(FATAL) << "Currently unsupported op: " << op->op;
} else if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1)
<< "CodeGenTileLangCuTeDSL only supports flat memory";
os << GetBufferPtr_(load->buffer.get(), load->indices[0]);
} else {
CodeGenTileLangPY::VisitExpr_(op, os);
}
}
// Emit CuTeDSL source for a BufferLoad.
//
// Three access shapes are handled:
//  * lanes(value) == lanes(element): a direct reference via GetBufferRef_;
//    ".load()" is appended when the reference is a tensor expression
//    (GetBufferRef_ returns a string ending in ')' exactly in that case).
//  * vectorized contiguous load (stride-1 Ramp index): one vector reference.
//  * vectorized non-contiguous load: gather lane-by-lane into a freshly
//    allocated register tensor, then ".load()" the whole tensor.
void CodeGenTileLangCuTeDSL::VisitExpr_(const BufferLoadNode *op,
                                        std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(op->indices.size(), 1)
      << "Load from non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer load is not supported.";
  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;
  const int value_lanes = value_dtype.lanes();
  if (value_lanes == element_dtype.lanes()) {
    std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index);
    // Trailing ')' marks a tl.make_tensor_at_offset(...) expression (see the
    // comment above GetBufferRef_); it needs an explicit .load().
    if (ref.back() == ')') {
      ref += ".load()";
    }
    os << ref;
  } else {
    ICHECK_GE(value_lanes, element_dtype.lanes())
        << "Unsupported load/store: value lanes < buffer element lanes";
    // A Ramp index with stride 1 means the loaded lanes are adjacent.
    bool is_contiguous = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
            .Match(index)) {
      is_contiguous = true;
    }
    if (is_contiguous) {
      std::string ref =
          GetBufferRef_(value_dtype, op->buffer.get(), base.Eval());
      if (ref.back() == ')') {
        ref += ".load()";
      }
      os << ref;
    } else {
      ICHECK(element_dtype.is_scalar())
          << "buffer element type for non-contiguous load must be scalar "
             "currently";
      // Gather path: allocate a register tensor, copy each lane, then use
      // the loaded tensor as the expression value.
      std::string sret = name_supply_->FreshName("_");
      PrintIndent();
      stream << sret << " = tl.make_rmem_tensor((" << value_lanes << ",), ";
      PrintType(element_dtype, stream);
      stream << ")\n";
      // NOTE(review): vid is unused below (GetBufferRef_ resolves the var id
      // itself); GetVarID may also validate the mapping — confirm before
      // removing.
      std::string vid = GetVarID(buffer_var.get());
      const RampNode *ramp = index.as<RampNode>();
      ICHECK(ramp)
          << "Expected Ramp index for vectorized non-contiguous access";
      for (int i = 0; i < value_lanes; ++i) {
        auto idx_expr =
            arith::Analyzer().Simplify(ramp->base + ramp->stride * i);
        PrintIndent();
        stream << sret << "[" << i << "] = "
               << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr)
               << "\n";
      }
      os << sret << ".load()";
    }
  }
}
// Emit CuTeDSL source for a BufferStore, mirroring the BufferLoad cases:
// direct reference (plain "=" or ".store(...)"), contiguous vector store,
// or a lane-by-lane scatter for non-contiguous vector stores.
void CodeGenTileLangCuTeDSL::VisitStmt_(const BufferStoreNode *op) {
  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer store is not supported.";
  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index_expr = op->indices[0];
  // NOTE(review): buffer_var is unused below — candidate for removal.
  Var buffer_var = op->buffer->data;
  std::string value_str = PrintExpr_(op->value);
  int value_lanes = value_dtype.lanes();
  if (value_lanes == element_dtype.lanes()) {
    std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr);
    PrintIndent();
    // Trailing ')' marks a tensor expression (see the comment above
    // GetBufferRef_): use ".store(...)"; otherwise plain assignment works.
    if (ref.back() != ')') {
      stream << ref << " = " << RemoveOutermostParentheses(value_str) << "\n";
    } else {
      stream << ref << ".store(" << RemoveOutermostParentheses(value_str)
             << ")\n";
    }
  } else {
    // A stride-1 Ramp index means the target lanes are adjacent in memory.
    bool is_contiguous = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, value_lanes / element_dtype.lanes())
            .Match(index_expr)) {
      is_contiguous = true;
    }
    if (is_contiguous) {
      PrintVecStore_(op->buffer.get(), value_dtype, base.Eval(), value_str);
    } else {
      ICHECK(element_dtype.is_scalar())
          << "buffer element type for non-contiguous store must be scalar "
             "currently";
      // store elements separately
      // Bind the vector value to an SSA name so it is evaluated only once.
      value_str = SSAGetID(value_str, element_dtype);
      for (int i = 0; i < value_lanes; ++i) {
        const RampNode *ramp = index_expr.as<RampNode>();
        ICHECK(ramp);
        auto idx_expr =
            arith::Analyzer().Simplify(ramp->base + ramp->stride * i);
        PrintIndent();
        stream << GetBufferRef_(element_dtype, op->buffer.get(), idx_expr)
               << " = ";
        PrintVecElemLoad_(value_str, value_dtype, i, stream);
        stream << "\n";
      }
    }
  }
}
// Lower an Allocate to a CuTeDSL tensor/variable definition, dispatching on
// the pointer storage scope:
//   * "local.descriptor.wgmma" -> tl.GmmaDescriptor()
//   * "shared.dyn"             -> view over the dynamic smem pool
//   * "shared"                 -> statically sized smem tensor
//   * "local"                  -> register tensor of constant size
//   * "local.var"              -> plain scalar Python variable, optionally
//                                 initialized via the kLocalVarInit annotation
// Other descriptor / barrier scopes are currently rejected.
void CodeGenTileLangCuTeDSL::VisitStmt_(const AllocateNode *op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  // Remember the scope so GetBufferRef_ can pick the right access form.
  alloc_storage_scope_[op->buffer_var.get()] = scope;
  if (scope == "local.descriptor.wgmma") {
    stream << vid << " = tl.GmmaDescriptor()\n";
  } else if (scope == "local.descriptor.tcgen05_smem") {
    LOG(FATAL) << "Currently unsupported scope: " << scope;
  } else if (scope == "local.descriptor.tcgen05_instr") {
    LOG(FATAL) << "Currently unsupported scope: " << scope;
  } else if (scope == "shared.dyn") {
    stream << vid << " = tl.make_tensor(tl.get_dyn_smem(";
    PrintType(op->dtype, stream);
    // there is no bound check for Tensor access, so just set shape to 1
    stream << ", alignment=1024), (1,))\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now, but get "
        << constant_size << " for " << op->buffer_var->name_hint;
    if (scope == "shared") {
      stream << vid << " = tl.make_tensor(tl.alloc_smem(";
      PrintType(op->dtype, stream);
      stream << ", " << constant_size << "), (" << constant_size << ",))\n";
    } else if (scope == "shared.barrier") {
      ICHECK(false) << "Unsupported scope: " << scope;
    } else if (scope == "local") {
      // Bug fix: emit a proper 1-tuple shape "(N,)". The previous code
      // emitted "make_rmem_tensor((N),dtype)" where "(N)" is just a
      // parenthesized int in Python — inconsistent with every other
      // make_rmem_tensor call site in this file, which all pass "(N,)".
      stream << vid << " = tl.make_rmem_tensor((" << constant_size << ",), ";
      PrintType(op->dtype, stream);
      stream << ")\n";
    } else if (scope == "local.var") {
      // Scalar register variable: default-initialize to zero unless the
      // allocate carries an explicit kLocalVarInit annotation (cast to the
      // variable dtype when the annotated dtype differs).
      PrimExpr init = tir::make_const(op->dtype, 0);
      auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
      if (init_it != op->annotations.end()) {
        PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
        if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
          user_init = tir::Cast(op->dtype, user_init);
        }
        init = user_init;
      }
      stream << vid << " = " << PrintExpr_(init) << "\n";
    } else {
      ICHECK(false) << "Unsupported scope: " << scope;
    }
  }
  RegisterHandleType_(op->buffer_var.get(), op->dtype);
  PrintStmt_(op->body);
}
// Lower AttrStmt nodes carrying codegen-relevant annotations; anything else
// falls through to the generic Python lowering.
void CodeGenTileLangCuTeDSL::VisitStmt_(const AttrStmtNode *op) {
  if (op->attr_key == tir::attr::thread_extent) {
    // Bind threadIdx/blockIdx vars on first sight, then emit the body.
    IterVar iv = Downcast<IterVar>(op->node);
    if (!iv->thread_tag.empty()) {
      if (!var_idmap_.count(iv->var.get())) {
        BindThreadIndex_(iv);
      }
    }
    VisitStmt(op->body);
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    // Emit the async body first, then commit the group.
    VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    VisitExpr(commit_group, stream);
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
    auto wait_cnt = wait_attrs.second;
    // Emit the wait first, then the inner body (skipping the nested
    // AttrStmt wrapper).
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
    VisitExpr(wait_group, stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    VisitStmt(inner->body);
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
    const StringImmNode *pattern = op->value.as<StringImmNode>();
    ICHECK(pattern);
    std::string call_str = pattern->value;
    // replace :: with . and replace < with ( and replace > with )
    // (turns a C++-style symbol into a Python call expression)
    ReplaceAll(call_str, "::", ".");
    ReplaceAll(call_str, "<", "(");
    ReplaceAll(call_str, ">", ")");
    this->stream << "blockIdx = " << call_str << "\n";
    this->VisitStmt(op->body);
  } else if (op->attr_key == "pragma_unroll_factor") {
    // Record the per-loop-var unroll factor; the ForNode visitor reads it.
    const IntImmNode *factor = op->value.as<IntImmNode>();
    ICHECK(factor);
    unroll_factor_[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
    CodeGenTileLangPY::VisitStmt_(op);
  } else {
    CodeGenTileLangPY::VisitStmt_(op);
  }
}
// Emit an unrolled loop as "for v in cutlass.range(...)" (or
// "range_constexpr(...)" for small constant extents). Non-unrolled loops
// use the generic Python lowering.
void CodeGenTileLangCuTeDSL::VisitStmt_(const ForNode *op) {
  if (op->kind != tir::ForKind::kUnrolled) {
    CodeGenTileLangPY::VisitStmt_(op);
    return;
  }
  auto start_expr = arith::Analyzer().Simplify(op->min);
  auto stop_expr = arith::Analyzer().Simplify(op->extent + op->min);
  // A pragma_unroll_factor annotation (recorded by the AttrStmt visitor)
  // overrides the default full-unroll behavior.
  std::string unroll_factor;
  if (auto it = unroll_factor_.find(op->loop_var.get());
      it != unroll_factor_.end()) {
    unroll_factor = PrintExpr_(it->second);
  }
  // Small constant-extent loops without an explicit factor use
  // range_constexpr; everything else gets unroll=/unroll_full=True.
  bool use_range_constexpr = unroll_factor.empty() &&
                             as_const_int(op->extent) != nullptr &&
                             *as_const_int(op->extent) <= LOOP_UNROLL_THRESHOLD;
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  stream << "for " << vid << " in cutlass.range";
  if (use_range_constexpr) {
    stream << "_constexpr";
  }
  stream << "(";
  // Omit the start argument when it is zero, matching Python range().
  if (!is_zero(start_expr)) {
    PrintExpr_(start_expr, stream);
    stream << ", ";
  }
  PrintExpr_(stop_expr, stream);
  if (!unroll_factor.empty()) {
    stream << ", unroll=" << unroll_factor;
  } else if (!use_range_constexpr) {
    stream << ", unroll_full=True";
  }
  stream << "):\n";
  int for_scope = BeginScope();
  PrintStmt_(op->body);
  EndScope(for_scope);
}
// Emit a Python if/else. When the condition is tl.tl_shuffle_elect(), the
// then-branch is additionally nested inside "with cute.arch.elect_one():".
void CodeGenTileLangCuTeDSL::VisitStmt_(const IfThenElseNode *op) {
  std::string cond = PrintExpr_(op->condition);
  PrintIndent();
  stream << "if " << RemoveOutermostParentheses(cond) << ":\n";
  int then_scope = BeginScope();
  if (const CallNode *call = op->condition.as<CallNode>();
      call && call->op.same_as(tl::tl_shuffle_elect())) {
    PrintIndent();
    stream << "with cute.arch.elect_one():\n";
    int with_scope = BeginScope();
    PrintStmt_(op->then_case);
    EndScope(with_scope);
  } else {
    PrintStmt_(op->then_case);
  }
  EndScope(then_scope);
  if (op->else_case) {
    PrintIndent();
    stream << "else:\n";
    int else_scope = BeginScope();
    PrintStmt_(op->else_case.value());
    EndScope(else_scope);
  }
}
void CodeGenTileLangCuTeDSL::VisitStmt_(const EvaluateNode *op) {
if (is_const_int(op->value))
return;
const CallNode *call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
LOG(FATAL) << "Currently unsupported op: " << call->op;
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0]));
PrintIndent();
stream << "assert " << cond << "\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = RemoveOutermostParentheses(PrintExpr_(call->args[0]));
std::string msg_expr = PrintExpr_(call->args[1]);
PrintIndent();
stream << "assert " << cond << ", " << msg_expr << "\n";
} else if (call && call->op.same_as(builtin::tvm_storage_sync())) {
PrintStorageSync_(call);
} else {
CodeGenTileLangPY::VisitStmt_(op);
}
}
// Print the expression for lane `i` of `vec`. Scalars have no lanes and are
// printed verbatim.
void CodeGenTileLangCuTeDSL::PrintVecElemLoad_(const std::string &vec,
                                               DataType t, int i,
                                               std::ostream &os) { // NOLINT(*)
  if (!t.is_scalar()) {
    os << vec << "[" << i << "]";
  } else {
    os << vec;
  }
}
// Emit an assignment of `value` into lane `i` of `vec` as a full statement.
void CodeGenTileLangCuTeDSL::PrintVecElemStore_(const std::string &vec,
                                                DataType t, int i,
                                                const std::string &value) {
  std::ostringstream line;
  line << vec << "[" << i << "] = " << value << "\n";
  PrintIndent();
  stream << line.str();
}
// Emit a whole-vector store "ref.store(value)" for a contiguous target.
void CodeGenTileLangCuTeDSL::PrintVecStore_(const BufferNode *buffer,
                                            DataType t, PrimExpr base,
                                            const std::string &value) {
  ICHECK(!t.is_scalar()) << "PrintVecStore_() should not be used for scalar";
  const std::string tensor_ref = GetBufferRef_(t, buffer, base);
  PrintIndent();
  stream << tensor_ref << ".store(" << value << ")\n";
}
// Emit a vector binary operation: allocate a register tensor for the
// result, compute it (either as one whole-vector statement for operators
// TensorSSA handles elementwise, or lane-by-lane otherwise), and print
// "result.load()" as the expression value.
void CodeGenTileLangCuTeDSL::PrintVecBinaryOp_(const std::string &opstr,
                                               DataType dtype, PrimExpr lhs,
                                               PrimExpr rhs,
                                               std::ostream &os) { // NOLINT(*)
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  PrintIndent();
  stream << sret << " = tl.make_rmem_tensor((" << dtype.lanes() << ",), ";
  PrintType(dtype.element_of(), stream);
  stream << ")\n";
  // Bind both operands to SSA names so each is evaluated only once.
  std::string vlhs = SSAGetID(PrintExpr_(lhs), lhs.dtype());
  std::string vrhs = SSAGetID(PrintExpr_(rhs), rhs.dtype());
  // NOTE(review): membership below is tested with substring find, so e.g.
  // "==" matches inside "// == != <= >=". All operators currently passed by
  // callers are listed, but an unlisted two-char operator containing ' ' or
  // '=' could be misclassified — confirm callers before extending.
  const std::string one_char_op{"+-*%<>^|&"};
  const std::string two_char_op{"// == != <= >="};
  if ((opstr.size() == 1 && one_char_op.find(opstr) != std::string::npos) ||
      (opstr.size() == 2 && two_char_op.find(opstr) != std::string::npos)) {
    PrintIndent();
    stream << sret << ".store(" << vlhs << " " << opstr << " " << vrhs << ")\n";
  } else {
    // Unpack into individual ops.
    for (int i = 0, lanes = dtype.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(opstr[0])) {
        // Function-style operator: "op(a[i], b[i])".
        value_temp << opstr << "(";
        PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        // Infix operator: "(a[i]<op>b[i])".
        value_temp << "(";
        PrintVecElemLoad_(vlhs, lhs.dtype(), i, value_temp);
        value_temp << opstr;
        PrintVecElemLoad_(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore_(sret, dtype, i, value_temp.str());
    }
  }
  os << sret << ".load()";
}
// Dispatch binary expression printing: vectors take the CuTeDSL vector path,
// scalars reuse the generic Python emission.
void CodeGenTileLangCuTeDSL::PrintBinaryExpr_(const std::string &opstr,
                                              DataType dtype, PrimExpr lhs,
                                              PrimExpr rhs,
                                              std::ostream &os) { // NOLINT(*)
  if (!dtype.is_scalar()) {
    PrintVecBinaryOp_(opstr, dtype, lhs, rhs, os);
    return;
  }
  CodeGenTileLangPY::PrintBinaryExpr_(opstr, dtype, lhs, rhs, os);
}
// Dispatch binary intrinsic printing: vectors take the CuTeDSL vector path,
// scalars reuse the generic Python emission.
void CodeGenTileLangCuTeDSL::PrintBinaryIntrinsic_(
    const CallNode *op, const char *opstr,
    std::ostream &os) { // NOLINT(*)
  if (!op->dtype.is_scalar()) {
    PrintVecBinaryOp_(opstr, op->dtype, op->args[0], op->args[1], os);
    return;
  }
  CodeGenTileLangPY::PrintBinaryIntrinsic_(op, opstr, os);
}
// Emit a call to an extern function. `global_symbol` may be a C++-style
// symbol ("ns::fn<T, true>") and is rewritten into Python syntax:
//   * "::" becomes "."
//   * a trailing template list "<...>" is merged into the call arguments
//     (with true/false capitalized for Python)
//   * any remaining "<...>" groups become "(...)" calls
// Special handling is applied for Atomic* helpers (first Buffer argument is
// converted to a pointer), tl.gemm_* (keyword arguments for the last three
// pointers), fastmath-canonicalized math functions, and fixed-length vector
// returns (scalarized lane by lane into a register tensor).
void CodeGenTileLangCuTeDSL::PrintCallExtern_(Type ret_type,
                                              ffi::String global_symbol,
                                              const ffi::Array<PrimExpr> &args,
                                              bool skip_first_arg,
                                              std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  std::string global_symbol_str = global_symbol;
  ReplaceAll(global_symbol_str, "::", ".");
  std::vector<std::string> sargs;
  // when the template arguments occur at the end, merge them with function
  // arguments
  if (global_symbol_str.back() == '>') {
    auto pos = global_symbol_str.rfind('<');
    ICHECK(pos != std::string::npos);
    std::string template_args =
        global_symbol_str.substr(pos + 1, global_symbol_str.size() - pos - 2);
    ReplaceAll(template_args, "true", "True");
    ReplaceAll(template_args, "false", "False");
    sargs.push_back(template_args);
    global_symbol_str.resize(pos);
  }
  const size_t arg_begin = static_cast<size_t>(skip_first_arg);
  for (size_t i = arg_begin; i < args.size(); ++i) {
    std::string sarg = PrintExpr_(args[i]);
    if (ret_dtype.is_fixed_length_vector()) {
      // Vector returns are scalarized below; bind each argument to an SSA
      // name so it is evaluated only once.
      std::string val = SSAGetID(sarg, args[i].dtype());
      sargs.push_back(std::move(val));
    } else {
      sargs.push_back(sarg);
    }
  }
  // Replace "<...>" with "(...)". Nested "<" is not supported
  {
    auto pos_left = global_symbol_str.find('<');
    while (pos_left != std::string::npos) {
      auto pos_right = global_symbol_str.find('>', pos_left + 1);
      if (pos_right == std::string::npos) {
        // Bug fix: an unmatched '<' previously left the string untouched
        // while find('<') kept returning the same position — an infinite
        // loop. Nothing more can be rewritten, so bail out.
        break;
      }
      auto tmpl_args =
          global_symbol_str.substr(pos_left + 1, pos_right - pos_left - 1);
      ReplaceAll(tmpl_args, "true", "True");
      ReplaceAll(tmpl_args, "false", "False");
      global_symbol_str.replace(pos_left, tmpl_args.size() + 2,
                                "(" + tmpl_args + ")");
      pos_left = global_symbol_str.find('<');
    }
  }
  // Special cases:
  // Map C math functions to Python/cutedsl equivalents
  const auto canonicalized_global_symbol_str =
      CanonicalizeFastmathFunctionName_(global_symbol_str);
  const bool canonicalized = !canonicalized_global_symbol_str.empty();
  if (canonicalized) {
    global_symbol_str = canonicalized_global_symbol_str;
  }
  // Atomic Functions
  if (global_symbol_str.substr(0, 6) == "Atomic") {
    global_symbol_str = "tl." + global_symbol_str;
    // Convert first argument (Buffer) to pointer for atomic operations
    if (const BufferLoadNode *load = args[arg_begin].as<BufferLoadNode>()) {
      ICHECK_EQ(load->indices.size(), 1)
          << "CodeGenTileLangCuTeDSL only supports flat memory";
      sargs[0] = GetBufferPtr_(load->buffer.get(), load->indices[0]);
    }
  }
  // some optional template arguments might be omitted, so add names
  // explicitly for the remaining arguments
  if (global_symbol_str == "tl.gemm_ss" || global_symbol_str == "tl.gemm_rs" ||
      global_symbol_str == "tl.gemm_sr" || global_symbol_str == "tl.gemm_rr") {
    ICHECK(sargs.size() >= 3);
    sargs[sargs.size() - 3] = "A_ptr=" + sargs[sargs.size() - 3];
    sargs[sargs.size() - 2] = "B_ptr=" + sargs[sargs.size() - 2];
    sargs[sargs.size() - 1] = "C_ptr=" + sargs[sargs.size() - 1];
  }
  if (ret_dtype.is_fixed_length_vector()) {
    // maybe simplify this if TensorSSA supports this OP
    std::string sret = name_supply_->FreshName("_");
    PrintIndent();
    stream << sret << " = tl.make_rmem_tensor((" << ret_dtype.lanes() << ",), ";
    PrintType(ret_dtype.element_of(), stream);
    stream << ")\n";
    // Emit a scalar call for each lane.
    bool has_template_arg = (sargs.size() > args.size() - arg_begin);
    for (int i = 0; i < ret_dtype.lanes(); ++i) {
      std::ostringstream scall;
      scall << global_symbol_str << "(";
      for (size_t j = 0; j < sargs.size(); ++j) {
        if (j != 0) {
          scall << ", ";
        }
        if (j == 0 && has_template_arg) {
          // The merged template-argument string is passed through verbatim.
          scall << sargs[j];
        } else {
          PrintVecElemLoad_(
              sargs[j],
              args[arg_begin + j - static_cast<size_t>(has_template_arg)]
                  .dtype(),
              i, scall);
        }
      }
      if (canonicalized && enable_fastmath_) {
        if (!sargs.empty()) {
          scall << ", ";
        }
        scall << "fastmath=True";
      }
      scall << ")";
      PrintVecElemStore_(sret, ret_dtype, i, scall.str());
    }
    os << sret << ".load()";
  } else {
    os << global_symbol_str << "(";
    for (size_t i = 0; i < sargs.size(); ++i) {
      if (i != 0) {
        os << ", ";
      }
      os << sargs[i];
    }
    if (canonicalized && enable_fastmath_) {
      if (!sargs.empty()) {
        os << ", ";
      }
      os << "fastmath=True";
    }
    os << ")";
  }
}
std::string CodeGenTileLangCuTeDSL::GetBufferPtr_(const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
const std::string vid = GetVarID(buffer_var);
DataType buffer_element_dtype = buffer->dtype;
bool is_handle_type_match =
HandleTypeMatch_(buffer_var, buffer_element_dtype);
std::string ptr_str;
if (is_handle_type_match) {
ptr_str = vid + ".iterator";
} else {
ptr_str = "tl.recast_ptr(" + vid +
".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")";
}
std::string index_str = PrintExpr_(index);
return "(" + ptr_str + " + " + index_str + ")";
}
// The following forms can be returned:
// (1) vid
// (2) vid[i]
// (3) tl.make_tensor_at_offset(...)[0]
// (4) tl.make_tensor_at_offset(...)
//
// Form (4) is needed when the whole tensor is loaded or stored.
// It's the only form that ends with ")". Using this fact, BufferLoadNode will
// add ".load()" and BufferStoreNode will add ".store()".
std::string CodeGenTileLangCuTeDSL::GetBufferRef_(DataType t,
                                                  const BufferNode *buffer,
                                                  PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
  std::string vid = GetVarID(buffer_var);
  // Prefer the scope recorded at allocation time; fall back to the scope
  // encoded in the pointer type.
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }
  // Scalar variables and descriptors are plain Python variables: form (1).
  if (scope == "local.var" || scope.find("local.descriptor") == 0) {
    return vid;
  }
  DataType buffer_element_dtype = buffer->dtype;
  bool is_handle_type_match =
      HandleTypeMatch_(buffer_var, buffer_element_dtype);
  std::string ptr_str;
  if (is_handle_type_match) {
    ptr_str = vid + ".iterator";
  } else {
    // Handle type differs from the element dtype: recast the iterator.
    ptr_str = "tl.recast_ptr(" + vid +
              ".iterator, dtype=" + DTypeToString(buffer_element_dtype) + ")";
  }
  const std::string index_str = PrintExpr_(index);
  if (t == buffer_element_dtype) {
    if (is_handle_type_match && buffer_element_dtype.is_scalar() &&
        (scope == "local" || scope == "shared" || scope == "shared.dyn" ||
         scope == "shared.barrier")) {
      // Tensors in these scopes are allocated as one-dimensional, so can be
      // accessed via "[]" correctly. Other tensors may be multi-dimensional,
      // and must be accessed via ptr, otherwise CuTeDSL will interpret "[]"
      // access using its visiting order and layout.
      return vid + "[" + index_str + "]";
    } else {
      std::ostringstream os;
      os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str
         << ", (1,), div_by=" << buffer_element_dtype.lanes() << ")";
      // for vector data types, ".load()" (added by BufferLoadNode) is needed
      // instead of "[0]"
      if (buffer_element_dtype.is_scalar()) {
        os << "[0]";
      }
      return os.str();
    }
  } else {
    // Reinterpreting view: `t` differs from the buffer element type, so the
    // view length must cover the same number of bits.
    const int num = t.bits() * t.lanes();
    const int den = buffer_element_dtype.bits() * buffer_element_dtype.lanes();
    ICHECK_EQ(num % den, 0) << "Cannot form view: bitwidth not divisible";
    int buffer_size = num / den;
    std::ostringstream os;
    os << "tl.make_tensor_at_offset(" << ptr_str << ", " << index_str << ", ("
       << buffer_size << ",), div_by=" << buffer_size << ")";
    return os.str();
  }
}
void CodeGenTileLangCuTeDSL::BindThreadIndex_(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
auto &thread_tag = iv->thread_tag;
ICHECK(thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
thread_tag == "threadIdx.z" || thread_tag == "blockIdx.x" ||
thread_tag == "blockIdx.y" || thread_tag == "blockIdx.z");
// cute.arch.thread_idx() and block_idx() are Int32
DataType from_dtype = DataType::Int(32);
var_idmap_[iv->var.get()] =
CastFromTo_(thread_tag, from_dtype, iv->var.dtype());
}
// Emit the CuTeDSL equivalent of tvm_storage_sync(scope, ...):
//   * "warp": nothing to emit
//   * "shared"/"shared.dyn": full sync, or a partial barrier when a
//     barrier id (and optionally a thread count) is supplied
//   * anything else: fatal
void CodeGenTileLangCuTeDSL::PrintStorageSync_(const CallNode *op) {
  const auto &call_args = op->args;
  const std::string &sync_scope = call_args[0].as<StringImmNode>()->value;
  if (sync_scope == "warp") {
    return;
  }
  if (sync_scope == "shared" || sync_scope == "shared.dyn") {
    PrintIndent();
    switch (call_args.size()) {
    case 1:
      stream << "tl.sync_threads()\n";
      break;
    case 2: {
      auto id_imm = call_args[1].as<IntImmNode>();
      ICHECK(id_imm) << "storage_sync barrier_id (args[1]) must be IntImm, got "
                     << call_args[1]->GetTypeKey();
      stream << "tl.sync_thread_partial(" << id_imm->value << ")\n";
      break;
    }
    case 3: {
      auto id_imm = call_args[1].as<IntImmNode>();
      ICHECK(id_imm) << "storage_sync barrier_id (args[1]) must be IntImm, got "
                     << call_args[1]->GetTypeKey();
      auto count_imm = call_args[2].as<IntImmNode>();
      ICHECK(count_imm)
          << "storage_sync thread_count (args[2]) must be IntImm, got "
          << call_args[2]->GetTypeKey();
      stream << "tl.sync_thread_partial(" << id_imm->value << ", "
             << count_imm->value << ")\n";
      break;
    }
    default:
      LOG(FATAL) << "Invalid number of arguments for storage sync: "
                 << call_args.size();
    }
  } else if (sync_scope == "global") {
    LOG(FATAL) << "PrintStorageSync_ for global is not supported for now";
  } else {
    LOG(FATAL) << "Unknown storage sync scope: " << sync_scope;
  }
}
} // namespace codegen
} // namespace tvm
/*!
* \file target/codegen_cutedsl.h
* \brief Utility to generate CuTeDSL code
*/
#ifndef TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#define TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "codegen_py.h"
namespace tvm {
namespace codegen {
/*!
 * \brief Code generator that lowers TileLang TIR into CuTeDSL Python source.
 *
 * Inherits the generic Python emission machinery from CodeGenTileLangPY and
 * overrides the pieces that need CuTeDSL-specific output (types, vector
 * loads/stores, allocations, thread binding, storage sync, extern calls).
 */
class CodeGenTileLangCuTeDSL final : public CodeGenTileLangPY {
public:
  CodeGenTileLangCuTeDSL();

protected:
  // Per-function prologue hooks.
  void PrintFuncDecorator_(std::ostream &os) override; // NOLINT(*)
  void PreFunctionBody_(const PrimFunc &f) override;

protected:
  // Expression / statement visitors overridden for CuTeDSL lowering.
  void PrintType(DataType t, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op,
                  std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op,
                  std::ostream &os) override;                     // NOLINT(*)
  void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const DivNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const MinNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const MaxNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const BufferLoadNode *op,
                  std::ostream &os) override; // NOLINT(*)
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitStmt_(const AllocateNode *op) override;
  void VisitStmt_(const AttrStmtNode *op) override;
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const IfThenElseNode *op) override;
  void VisitStmt_(const EvaluateNode *op) override;

protected:
  // Vector (multi-lane) emission helpers.
  virtual void PrintVecElemLoad_(const std::string &vec, DataType t, int i,
                                 std::ostream &os); // NOLINT(*)
  virtual void PrintVecElemStore_(const std::string &vec, DataType t, int i,
                                  const std::string &value);
  virtual void PrintVecStore_(const BufferNode *buffer, DataType t,
                              PrimExpr base, const std::string &value);
  void PrintVecBinaryOp_(const std::string &opstr, DataType dtype, PrimExpr lhs,
                         PrimExpr rhs,
                         std::ostream &os); // NOLINT(*)
  void PrintBinaryExpr_(const std::string &opstr, DataType dtype, PrimExpr lhs,
                        PrimExpr rhs,
                        std::ostream &os) override; // NOLINT(*)
  void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
                             std::ostream &os) override; // NOLINT(*)
  void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
                        const ffi::Array<PrimExpr> &args, bool skip_first_arg,
                        std::ostream &os) override; // NOLINT(*)
  // Buffer access helpers; see the comments on the definitions for the
  // string forms they produce.
  std::string GetBufferPtr_(const BufferNode *buffer, PrimExpr index);
  std::string GetBufferRef_(DataType t, const BufferNode *buffer,
                            PrimExpr index) override;
  /*!
   * \brief Print expr representing the thread tag
   * \param IterVar iv The thread index to be binded;
   */
  virtual void BindThreadIndex_(const IterVar &iv); // NOLINT(*)
  virtual void PrintStorageSync_(const CallNode *op);
  // Maps C math function names to their CuTeDSL equivalents; returns an
  // empty string when no mapping applies.
  std::string
  CanonicalizeFastmathFunctionName_(const std::string &func_name) const;

private:
  // The name of the mbarrier array in shared memory
  const std::string mbarrier_name_ = "mbarrier";
  // Unroll factors recorded from "pragma_unroll_factor" attributes, keyed
  // by loop variable; consumed by the ForNode visitor.
  std::unordered_map<const VarNode *, IntImm> unroll_factor_;
  std::vector<std::string> eviction_policy_names_ = {
      "EVICT_NORMAL", "EVICT_FIRST", "EVICT_LAST"};
  // Fastmath configuration (read from PassContext)
  bool enable_fastmath_ = false;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_CUTEDSL_H_
......@@ -829,6 +829,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::pack_b16())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(tl::__ldg())) {
// HIP fallback: regular load
const BufferLoadNode *bl = op->args[0].as<BufferLoadNode>();
ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument.";
ICHECK_EQ(bl->indices.size(), 1)
<< "T.__ldg currently supports flattened 1D buffer accesses.";
const BufferNode *buffer = bl->buffer.get();
PrimExpr base = bl->indices[0];
auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
os << buffer_ref;
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U);
......@@ -1256,9 +1266,9 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
if (op->value < 0) {
temp << "-";
}
temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
} else if (std::isnan(op->value)) {
temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32)
......@@ -1378,6 +1388,12 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
ICHECK(global_symbol.has_value())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
std::unordered_set<const VarNode *> non_restrict;
if (auto opt =
f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
for (const tir::Var &v : opt.value())
non_restrict.insert(v.get());
}
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
......@@ -1411,7 +1427,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
}
}
if (no_alias) {
if (no_alias && !non_restrict.count(v.get())) {
PrintRestrict(v, stream);
}
} else {
......
/*!
* \file codegen_py.cc
*/
#include "codegen_py.h"
#include "codegen_utils.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <cctype>
namespace tvm {
namespace codegen {
void CodeGenTileLangPY::AddFunction(const GlobalVar &gvar, const PrimFunc &f) {
RegisterFunction_(gvar, f);
auto function_name = GetFunctionName_(gvar);
// clear previous generated state.
InitFuncState_(f);
PrintFuncDecorator_(stream);
PrintFunctionSignature_(function_name, f, stream);
stream << ":\n";
int func_scope = BeginScope();
PreFunctionBody_(f);
PrintStmt_(f->body);
EndScope(func_scope);
}
// Concatenate the declaration and body streams into the final source string.
std::string CodeGenTileLangPY::Finish() {
  return decl_stream.str() + stream.str();
}
// Look up the emission name previously registered for `gvar`; fatal if the
// function has not been registered.
ffi::String CodeGenTileLangPY::GetFunctionName_(const GlobalVar &gvar) {
  auto entry = internal_functions_.find(gvar);
  ICHECK(entry != internal_functions_.end())
      << "Attempted to find name of " << gvar
      << ", but no function with this GlobalVar has been declared";
  return entry->second;
}
void CodeGenTileLangPY::RegisterFunction_(const GlobalVar &gvar,
const PrimFunc &func) {
if (internal_functions_.count(gvar)) {
return;
}
auto function_name = [&]() -> ffi::String {
if (auto global_symbol =
func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
auto name = global_symbol.value();
ICHECK(!func_name_supply_->ContainsName(name))
<< "Function " << gvar << " must use global symbol " << name
<< ", but this name has already been used.";
func_name_supply_->ReserveName(name);
return name;
} else {
ICHECK(!func_name_supply_->ContainsName(gvar->name_hint))
<< "Function " << gvar << " must use name hint " << gvar->name_hint
<< ", but this name has already been used.";
func_name_supply_->ReserveName(gvar->name_hint);
return gvar->name_hint;
}
}();
internal_functions_.insert({gvar, function_name});
}
// Reset per-function codegen state (storage scopes, buffer element dtypes,
// base-class SSA bookkeeping) and re-reserve identifiers that generated
// variable names must not collide with.
void CodeGenTileLangPY::InitFuncState_(const PrimFunc &f) {
  alloc_storage_scope_.clear();
  handle_data_type_.clear();
  CodeGenSourceBase::ClearFuncState();
  ReserveKeywordsAsUnique_();
}
// Emit "def <function_name>(<p0>, <p1>, ...)" (without the trailing colon)
// and record the element dtype of every pointer-typed parameter so later
// buffer accesses can be type-checked.
void CodeGenTileLangPY::PrintFunctionSignature_(
    const ffi::String &function_name, const PrimFunc &func,
    std::ostream &os) { // NOLINT(*)
  os << "def " << function_name << "(";
  bool first = true;
  for (const tir::Var &param : func->params) {
    if (!first) {
      os << ", ";
    }
    first = false;
    os << AllocVarID(param.get());
  }
  os << ")";
  // Register handle data type
  for (const auto &param : func->params) {
    if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType_(param.get(), prim->dtype);
      }
    }
  }
}
// Reserve names that generated identifiers must never shadow: "_" (so SSA
// temporaries start from _1), Python keywords and literals, C scalar type
// names, and library namespaces used by the emitted code.
void CodeGenTileLangPY::ReserveKeywordsAsUnique_() {
  static const char *const kReservedNames[] = {
      "_",        "False",    "None",   "True",     "and",      "as",
      "assert",   "async",    "await",  "break",    "class",    "continue",
      "def",      "del",      "elif",   "else",     "except",   "finally",
      "for",      "from",     "global", "if",       "import",   "in",
      "is",       "lambda",   "nonlocal", "not",    "or",       "pass",
      "raise",    "return",   "try",    "while",    "with",     "yield",
      "void",     "int",      "float",  "double",   "char",     "unsigned",
      "short",    "long",     "cutlass", "cute",    "tl"};
  for (const char *name : kReservedNames) {
    name_supply_->ReserveName(name);
  }
}
// Emit the Python assignment "target = src" for an SSA temporary; redundant
// wrapping parentheses on the source expression are stripped for readability.
void CodeGenTileLangPY::PrintSSAAssign(const std::string &target,
                                       const std::string &src, DataType t) {
  stream << target << " = " << RemoveOutermostParentheses(src) << "\n";
}
// Map a TIR scalar dtype to the closest Python builtin type name:
// float16/32/64 -> "float", (u)int8/16/32/64 -> "int", 1-bit -> "bool".
// Any other dtype is a fatal error.
void CodeGenTileLangPY::PrintType(DataType type,
                                  std::ostream &os) { // NOLINT(*)
  if (type.is_float()) {
    if (type.bits() == 16 || type.bits() == 32 || type.bits() == 64) {
      os << "float";
      return;
    }
    LOG(FATAL) << "Cannot convert float" << type.bits() << " to Python type";
  } else if (type.is_uint() || type.is_int()) {
    // Signedness only affects the diagnostic text; both map identically.
    const char *family = type.is_uint() ? "uint" : "int";
    switch (type.bits()) {
    case 8:
    case 16:
    case 32:
    case 64:
      os << "int";
      return;
    case 1:
      os << "bool";
      return;
    default:
      LOG(FATAL) << "Cannot convert " << family << type.bits()
                 << " to Python type";
    }
  } else {
    LOG(FATAL) << "Cannot convert type " << type << " to Python type";
  }
}
// Identifier reference: print the unique name assigned to this Var.
void CodeGenTileLangPY::VisitExpr_(const VarNode *op,
                                   std::ostream &os) { // NOLINT(*)
  os << GetVarID(op);
}
// Integer immediate: booleans become True/False; other widths print the
// decimal value (MarkConst flags the text as a constant for SSA reuse).
void CodeGenTileLangPY::VisitExpr_(const IntImmNode *op,
                                   std::ostream &os) { // NOLINT(*)
  if (op->dtype == DataType::Bool()) {
    os << (op->value ? "True" : "False");
  } else {
    std::ostringstream temp;
    temp << op->value;
    MarkConst(temp.str());
    os << temp.str();
  }
}
// Float immediate: emitted via float.fromhex so the value round-trips
// bit-exactly; fp16 is additionally wrapped in the Python type name.
void CodeGenTileLangPY::VisitExpr_(const FloatImmNode *op,
                                   std::ostream &os) { // NOLINT(*)
  switch (op->dtype.bits()) {
  case 64:
  case 32: {
    std::ostringstream temp;
    temp << "float.fromhex('" << std::hexfloat << op->value << "')";
    MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    PrintType(op->dtype, os);
    os << "(float.fromhex('" << std::hexfloat << op->value << "'))";
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}
// String immediate: print as an escaped Python double-quoted literal.
void CodeGenTileLangPY::VisitExpr_(const StringImmNode *op,
                                   std::ostream &os) { // NOLINT(*)
  EscapeStringLiteral_(op->value, os);
}
// Cast: wrap the value in the Python constructor of the target type
// (no-op when source and target dtypes already agree).
void CodeGenTileLangPY::VisitExpr_(const CastNode *op,
                                   std::ostream &os) { // NOLINT(*)
  std::stringstream value;
  PrintExpr_(op->value, value);
  os << CastFromTo_(value.str(), op->value.dtype(), op->dtype);
}
// Arithmetic, comparison, and logical visitors.  Each delegates to
// PrintBinaryExpr_, which chooses between infix "(a op b)" and call-style
// "op(a, b)" (used for min/max) syntax.
void CodeGenTileLangPY::VisitExpr_(const AddNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("+", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const SubNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("-", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const MulNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("*", op->dtype, op->a, op->b, os);
}
// Integer division uses Python "//".
// NOTE(review): Python "//" is floor division while TIR Div truncates toward
// zero; the results differ for negative operands — confirm callers only
// produce non-negative operands here.
void CodeGenTileLangPY::VisitExpr_(const DivNode *op,
                                   std::ostream &os) { // NOLINT(*)
  if (op->dtype.is_int() || op->dtype.is_uint()) {
    PrintBinaryExpr_("//", op->dtype, op->a, op->b, os);
  } else {
    PrintBinaryExpr_("/", op->dtype, op->a, op->b, os);
  }
}
void CodeGenTileLangPY::VisitExpr_(const ModNode *op,
                                   std::ostream &os) { // NOLINT(*)
  ICHECK(op->dtype.is_int() || op->dtype.is_uint() || op->dtype.is_float())
      << "Expected floating point or integer dtype in Mod, but got "
      << op->dtype;
  PrintBinaryExpr_("%", op->dtype, op->a, op->b, os);
}
// min/max are alphabetic, so PrintBinaryExpr_ renders them as calls.
void CodeGenTileLangPY::VisitExpr_(const MinNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("min", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const MaxNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("max", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const EQNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("==", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const NENode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("!=", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const LTNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("<", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const LENode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("<=", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const GTNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_(">", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const GENode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_(">=", op->dtype, op->a, op->b, os);
}
// "and"/"or" are keywords, not call syntax — PrintBinaryExpr_ special-cases
// them back to infix form.
void CodeGenTileLangPY::VisitExpr_(const AndNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("and", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const OrNode *op,
                                   std::ostream &os) { // NOLINT(*)
  PrintBinaryExpr_("or", op->dtype, op->a, op->b, os);
}
void CodeGenTileLangPY::VisitExpr_(const NotNode *op,
                                   std::ostream &os) { // NOLINT(*)
  os << "(not ";
  PrintExpr_(op->a, os);
  os << ")";
}
// Select: rendered as a Python conditional expression
// "(true_value if condition else false_value)".
void CodeGenTileLangPY::VisitExpr_(const SelectNode *op,
                                   std::ostream &os) { // NOLINT(*)
  os << "(";
  PrintExpr_(op->true_value, os);
  os << " if ";
  PrintExpr_(op->condition, os);
  os << " else ";
  PrintExpr_(op->false_value, os);
  os << ")";
}
// Ramp: expanded to a tuple "(base+stride*0, base+stride*1, ...)".
// NOTE(review): for lanes == 1 this prints "(expr)" which Python does not
// treat as a tuple — presumably Ramp nodes here always have lanes > 1;
// verify with the lowering passes.
void CodeGenTileLangPY::VisitExpr_(const RampNode *op,
                                   std::ostream &os) { // NOLINT(*)
  int lanes = op->dtype.lanes();
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr_(op->base) << ")"
       << "+(" << PrintExpr_(op->stride) << "*" << i << ")";
    if (i != lanes - 1)
      os << ", ";
  }
  os << ")";
}
// Call dispatch.  Built-in ops are mapped to Python control flow, operators,
// or extern calls; GlobalVar callees become calls to other generated
// functions; anything else is a fatal error.
void CodeGenTileLangPY::VisitExpr_(const CallNode *op,
                                   std::ostream &os) { // NOLINT(*)
  if (auto opt_call_op = op->op.as<Op>()) {
    const auto &call_op = opt_call_op.value();
    if (op->op.same_as(builtin::ret())) {
      // "return <expr>" — outer parentheses on the value are redundant.
      os << "return " << RemoveOutermostParentheses(PrintExpr_(op->args[0]));
    } else if (op->op.same_as(builtin::continue_loop())) {
      os << "continue";
    } else if (op->op.same_as(builtin::break_loop())) {
      os << "break";
    } else if (op->op.same_as(builtin_call_extern_) ||
               op->op.same_as(builtin_call_pure_extern_)) {
      // args[0] carries the callee symbol; remaining args are the call args.
      ICHECK_GE(op->args.size(), 1U);
      auto func = Downcast<StringImm>(op->args[0]);
      PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)), func->value,
                       op->args, true, os);
    } else if (op_attr_global_symbol_.count(call_op)) {
      // call extern if the op itself have a global symbol.
      PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)),
                       op_attr_global_symbol_[call_op], op->args, false, os);
    } else if (op->op.same_as(builtin::large_uint_imm())) {
      // Reassemble a 64-bit constant from its (low32, high32) halves.
      ICHECK_EQ(op->args.size(), 2U);
      uint64_t low =
          static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
      uint64_t high =
          static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
      uint64_t val = (high << 32U) | low;
      if (op->dtype == DataType::UInt(32)) {
        std::ostringstream temp;
        temp << val;
        MarkConst(temp.str());
        os << temp.str();
      } else {
        PrintType(op->dtype, os);
        os << "(" << val << ")";
      }
    } else if (op->op.same_as(builtin::bitwise_and())) {
      PrintBinaryIntrinsic_(op, "&", os);
    } else if (op->op.same_as(builtin::bitwise_or())) {
      PrintBinaryIntrinsic_(op, "|", os);
    } else if (op->op.same_as(builtin::bitwise_xor())) {
      PrintBinaryIntrinsic_(op, "^", os);
    } else if (op->op.same_as(builtin::bitwise_not())) {
      ICHECK_EQ(op->args.size(), 1U);
      os << "~";
      PrintExpr_(op->args[0], os);
    } else if (op->op.same_as(builtin::shift_left())) {
      PrintBinaryIntrinsic_(op, "<<", os);
    } else if (op->op.same_as(builtin::shift_right())) {
      PrintBinaryIntrinsic_(op, ">>", os);
    } else if (op->op.same_as(builtin::if_then_else())) {
      // Conditional expression keeps if_then_else's lazy-evaluation shape.
      std::string cond = PrintExpr_(op->args[0]);
      std::string true_val = PrintExpr_(op->args[1]);
      std::string false_val = PrintExpr_(op->args[2]);
      os << "(" << true_val << " if " << cond << " else " << false_val << ")";
    } else if (op->op.same_as(builtin::isnullptr())) {
      ICHECK_EQ(op->args.size(), 1U);
      os << "(";
      PrintExpr_(op->args[0], os);
      os << " is None)";
    } else if (op->op.same_as(builtin::isnan())) {
      // NaN is the only float value unequal to itself.
      os << "(";
      PrintExpr_(op->args[0], os);
      os << " != ";
      PrintExpr_(op->args[0], os);
      os << ")";
    } else {
      LOG(FATAL) << "Unresolved call " << op->op;
    }
  } else if (auto opt = op->op.as<GlobalVar>()) {
    const auto &gvar = opt.value();
    auto callee_name = GetFunctionName_(gvar);
    PrintCallExtern_(GetType(ffi::GetRef<PrimExpr>(op)), callee_name, op->args,
                     false, os);
  } else {
    LOG(FATAL)
        << "CodeGenTileLangPY: Unknown operation " << op->op
        << " is neither a recognized built-in, "
        << "nor a GlobalVar reference to another function in the IRModule";
  }
}
// Buffer load: only flat (single-index), non-predicated accesses whose value
// dtype matches the buffer element dtype are supported; rendered as
// "name[index]".
void CodeGenTileLangPY::VisitExpr_(const BufferLoadNode *op,
                                   std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(op->indices.size(), 1)
      << "Load from non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer load is not supported.";
  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;
  ICHECK_EQ(value_dtype, element_dtype)
      << "value_dtype and element_dtype must be same for a BufferLoadNode";
  std::string ref = GetBufferRef_(op->dtype, op->buffer.get(), index);
  os << ref;
}
// Buffer store: mirror of the BufferLoad restrictions; emitted as the
// assignment "name[index] = value".
void CodeGenTileLangPY::VisitStmt_(const BufferStoreNode *op) {
  ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer store is not supported.";
  DataType value_dtype = op->value.dtype();
  DataType element_dtype = op->buffer->dtype;
  PrimExpr index_expr = op->indices[0];
  Var buffer_var = op->buffer->data;
  ICHECK_EQ(value_dtype, element_dtype)
      << "value_dtype and element_dtype must be same for a BufferStoreNode";
  std::string value = PrintExpr_(op->value);
  std::string ref = GetBufferRef_(value_dtype, op->buffer.get(), index_expr);
  PrintIndent();
  stream << ref << " = " << RemoveOutermostParentheses(value) << "\n";
}
// DeclBuffer carries no Python-visible state; only its body is emitted.
void CodeGenTileLangPY::VisitStmt_(const DeclBufferNode *op) {
  PrintStmt_(op->body);
}
// Let binding: a plain Python assignment followed by the body.
void CodeGenTileLangPY::VisitStmt_(const LetStmtNode *op) {
  std::string value = PrintExpr_(op->value);
  PrintIndent();
  stream << AllocVarID(op->var.get()) << " = " << value << "\n";
  PrintStmt_(op->body);
}
// Allocation: modeled as a Python list pre-filled with None.  Only
// constant-size allocations are supported; the buffer's storage scope and
// element dtype are recorded for later access checks.
void CodeGenTileLangPY::VisitStmt_(const AllocateNode *op) {
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  PrintIndent();
  size_t constant_size = op->ConstantAllocationSize();
  ICHECK_GT(constant_size, 0)
      << "Can only handle constant size stack allocation for now";
  auto scope = GetPtrStorageScope(op->buffer_var);
  alloc_storage_scope_[op->buffer_var.get()] = scope;
  stream << vid << " = [None] * " << constant_size << "\n";
  RegisterHandleType_(op->buffer_var.get(), op->dtype);
  PrintStmt_(op->body);
}
// Attribute statements are ignored by this backend; only the body matters.
void CodeGenTileLangPY::VisitStmt_(const AttrStmtNode *op) {
  PrintStmt_(op->body);
}
// For: lowered to a Python range() loop.  When min != 0 the exclusive upper
// bound (min + extent) is simplified first so the emitted code stays tidy.
void CodeGenTileLangPY::VisitStmt_(const ForNode *op) {
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  stream << "for " << vid << " in range(";
  if (is_zero(op->min)) {
    PrintExpr_(op->extent, stream);
  } else {
    PrintExpr_(op->min, stream);
    stream << ", ";
    PrimExpr upper_bound = arith::Analyzer().Simplify(op->extent + op->min);
    PrintExpr_(upper_bound, stream);
  }
  stream << "):\n";
  int for_scope = BeginScope();
  PrintStmt_(op->body);
  EndScope(for_scope);
}
// While loop: "while <cond>:" with an indented body.
void CodeGenTileLangPY::VisitStmt_(const WhileNode *op) {
  std::string cond = PrintExpr_(op->condition);
  PrintIndent();
  stream << "while " << RemoveOutermostParentheses(cond) << ":\n";
  int while_scope = BeginScope();
  PrintStmt_(op->body);
  EndScope(while_scope);
}
// Conditional: "if <cond>:" plus an optional "else:" branch, each in its own
// indentation scope.
void CodeGenTileLangPY::VisitStmt_(const IfThenElseNode *op) {
  std::string cond = PrintExpr_(op->condition);
  PrintIndent();
  stream << "if " << RemoveOutermostParentheses(cond) << ":\n";
  int then_scope = BeginScope();
  PrintStmt_(op->then_case);
  EndScope(then_scope);
  if (op->else_case) {
    PrintIndent();
    stream << "else:\n";
    int else_scope = BeginScope();
    PrintStmt_(op->else_case.value());
    EndScope(else_scope);
  }
}
// Sequence: emit members in order.
void CodeGenTileLangPY::VisitStmt_(const SeqStmtNode *op) {
  for (Stmt stmt : op->seq) {
    PrintStmt_(stmt);
  }
}
// Evaluate: constant expressions are pure no-ops and are dropped; anything
// else becomes a bare expression statement.
void CodeGenTileLangPY::VisitStmt_(const EvaluateNode *op) {
  if (is_const_int(op->value))
    return;
  std::string vid = PrintExpr_(op->value);
  if (!vid.empty()) {
    PrintIndent();
    stream << vid << "\n";
  }
}
// Assert: "assert cond[, "message"]" followed by the guarded body.
void CodeGenTileLangPY::VisitStmt_(const AssertStmtNode *op) {
  std::string cond = PrintExpr_(op->condition);
  PrintIndent();
  if (const auto *str = op->message.as<StringImmNode>()) {
    stream << "assert " << cond << ", ";
    EscapeStringLiteral_(str->value, stream);
    stream << "\n";
  } else {
    stream << "assert " << cond << "\n";
  }
  PrintStmt_(op->body);
}
// Wrap `value` in the Python constructor of `target` (e.g. "int(x)") when a
// conversion is needed; identical dtypes pass the text through unchanged.
std::string CodeGenTileLangPY::CastFromTo_(const std::string &value,
                                           DataType from, DataType target) {
  if (from == target) {
    return value;
  }
  std::ostringstream wrapped;
  PrintType(target, wrapped);
  wrapped << "(" << value << ")";
  return wrapped.str();
}
// Print a scalar binary expression, choosing between call syntax
// "op(lhs, rhs)" for alphabetic operator names such as min/max, and infix
// "(lhs op rhs)" for symbolic operators and the Python keywords "and"/"or".
// \param opstr  Operator text (never empty for valid callers).
// \param dtype  Result dtype; vector lanes are not supported.
void CodeGenTileLangPY::PrintBinaryExpr_(const std::string &opstr,
                                         DataType dtype, PrimExpr lhs,
                                         PrimExpr rhs,
                                         std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(dtype.lanes(), 1);
  // Guard against an empty operator, and cast to unsigned char before
  // std::isalpha: passing a plain (possibly negative) char is undefined
  // behavior per the <cctype> contract.
  if (!opstr.empty() &&
      std::isalpha(static_cast<unsigned char>(opstr[0])) && opstr != "and" &&
      opstr != "or") {
    os << opstr << '(';
    PrintExpr_(lhs, os);
    os << ", ";
    PrintExpr_(rhs, os);
    os << ')';
  } else {
    os << '(';
    PrintExpr_(lhs, os);
    os << ' ' << opstr << ' ';
    PrintExpr_(rhs, os);
    os << ')';
  }
}
// Print a scalar two-argument builtin (bitwise ops, shifts) using infix
// syntax: "(args[0] opstr args[1])".
void CodeGenTileLangPY::PrintBinaryIntrinsic_(const CallNode *op,
                                              const char *opstr,
                                              std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(op->dtype.lanes(), 1);
  ICHECK_EQ(op->args.size(), 2U);
  os << '(';
  PrintExpr_(op->args[0], os);
  os << ' ' << opstr << ' ';
  PrintExpr_(op->args[1], os);
  os << ')';
}
// Render "global_symbol(arg0, arg1, ...)".  When skip_first_arg is set the
// first element of args (the callee symbol of call_extern) is dropped.
void CodeGenTileLangPY::PrintCallExtern_(Type ret_type,
                                         ffi::String global_symbol,
                                         const ffi::Array<PrimExpr> &args,
                                         bool skip_first_arg,
                                         std::ostream &os) { // NOLINT(*)
  os << global_symbol << "(";
  const size_t begin = skip_first_arg ? 1 : 0;
  for (size_t i = begin; i < args.size(); ++i) {
    if (i != begin) {
      os << ", ";
    }
    PrintExpr_(args[i], os);
  }
  os << ")";
}
// Build the Python reference "name[index]" for a buffer element.  The buffer
// must have been registered with element dtype t (buffers are plain lists in
// the generated code, so no pointer casting is needed).
std::string CodeGenTileLangPY::GetBufferRef_(DataType t,
                                             const BufferNode *buffer,
                                             PrimExpr index) {
  const VarNode *data_var = buffer->data.get();
  std::string vid = GetVarID(data_var);
  const DataType element_dtype = buffer->dtype;
  ICHECK(HandleTypeMatch_(data_var, element_dtype));
  ICHECK_EQ(t, element_dtype);
  return vid + "[" + PrintExpr_(index) + "]";
}
// Record the element dtype of a buffer variable.  The first registration
// wins; any later registration must agree or it is a fatal error.
void CodeGenTileLangPY::RegisterHandleType_(const VarNode *buf_var,
                                            DataType t) {
  auto [it, inserted] = handle_data_type_.emplace(buf_var, t);
  if (!inserted) {
    ICHECK(it->second == t) << "conflicting buf var type";
  }
}
// True only when buf_var was previously registered with exactly dtype t.
bool CodeGenTileLangPY::HandleTypeMatch_(const VarNode *buf_var,
                                         DataType t) const {
  auto it = handle_data_type_.find(buf_var);
  return it != handle_data_type_.end() && it->second == t;
}
// Write `s` as a Python double-quoted string literal: backslash, quote, and
// common control characters get two-character escapes; any other byte below
// 0x20 (or DEL) is emitted as \xHH; bytes >= 128 pass through verbatim.
void CodeGenTileLangPY::EscapeStringLiteral_(const std::string &s,
                                             std::ostream &os) {
  static const char kHexDigits[] = "0123456789abcdef";
  os << '"';
  for (unsigned char c : s) {
    const char *escape = nullptr;
    switch (c) {
    case '\\':
      escape = "\\\\";
      break;
    case '"':
      escape = "\\\"";
      break;
    case '\n':
      escape = "\\n";
      break;
    case '\r':
      escape = "\\r";
      break;
    case '\t':
      escape = "\\t";
      break;
    case '\f':
      escape = "\\f";
      break;
    case '\b':
      escape = "\\b";
      break;
    default:
      break;
    }
    if (escape != nullptr) {
      os << escape;
    } else if (c < 32 || c == 127) {
      // Remaining non-printable ASCII: hex escape.
      os << "\\x" << kHexDigits[(c >> 4) & 0xF] << kHexDigits[c & 0xF];
    } else {
      os << c;
    }
  }
  os << '"';
}
} // namespace codegen
} // namespace tvm
/*!
* \file codegen_py.h
* \brief Common utilities to generate simple Python code.
*/
#ifndef TVM_TL_TARGET_CODEGEN_PY_H_
#define TVM_TL_TARGET_CODEGEN_PY_H_
#include <tvm/ir/op.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include <unordered_map>
// from tvm/src/
#include "target/source/codegen_source_base.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace codegen {
using namespace tir;
/*!
* \brief A base class to generate simple Python code.
*/
class CodeGenTileLangPY
    : public ExprFunctor<void(const PrimExpr &, std::ostream &)>,
      public StmtFunctor<void(const Stmt &)>,
      public CodeGenSourceBase {
public:
  /*!
   * \brief Add the function definition to the generated module.
   * \param gvar The GlobalVar representing the function.
   * \param func The function to be compiled.
   */
  virtual void AddFunction(const GlobalVar &gvar, const PrimFunc &func);
  /*!
   * \brief Finalize the compilation and return the code.
   * \return The code.
   */
  virtual std::string Finish();
protected:
  /*!
   * \brief Get the name of a declared function
   * \param gvar The GlobalVar of the function
   * \returns The string name of the function
   */
  ffi::String GetFunctionName_(const GlobalVar &gvar);
  /*!
   * \brief Reserve the function name in the generated module.
   *
   * Uses the kGlobalSymbol attribute when present, otherwise the GlobalVar's
   * name hint.  Idempotent for an already-registered gvar.
   * \param gvar The GlobalVar representing the function.
   * \param func The function to be compiled.
   */
  virtual void RegisterFunction_(const GlobalVar &gvar, const PrimFunc &func);
  /*!
   * \brief Initialize codegen state for generating f.
   * \param f The function to be compiled.
   */
  virtual void InitFuncState_(const PrimFunc &f);
  /*! \brief Print the function signature before ":"
   * \param function_name The name of the function
   * \param func The function whose signature should be printed
   * \param os The output stream
   */
  virtual void PrintFunctionSignature_(const ffi::String &function_name,
                                       const PrimFunc &func,
                                       std::ostream &os); // NOLINT(*)
  /*!
   * \brief Print the function decorator
   * \param os The output stream
   */
  virtual void PrintFuncDecorator_(std::ostream &os) {} // NOLINT(*)
  /*!
   * \brief Insert statement before function body.
   * \param f The function to be compiled.
   */
  virtual void PreFunctionBody_(const PrimFunc &f) {}
protected:
  /*! \brief reserves common Python keywords */
  void ReserveKeywordsAsUnique_();
  void PrintSSAAssign(const std::string &target, const std::string &src,
                      DataType t) override;
protected:
  /*!
   * \brief Print the Python representation of a TIR dtype.
   * \param type The type representation.
   * \param os The output stream
   */
  void PrintType(DataType type, std::ostream &os) override; // NOLINT(*)
  /*!
   * \brief Print the Stmt n to CodeGenTileLangPY->stream
   * \param n The statement to be printed.
   */
  void PrintStmt_(const Stmt &n) { VisitStmt(n); }
  /*!
   * \brief Print the expression n into os
   * \param n The expression to be printed.
   * \param os The output stream
   */
  void PrintExpr_(const PrimExpr &n, std::ostream &os) { // NOLINT(*)
    VisitExpr(n, os);
  }
  /*!
   * \brief Same as PrintExpr_, but simply returns result string
   * \param n The expression to be printed.
   */
  std::string PrintExpr_(const PrimExpr &n) {
    std::ostringstream os;
    PrintExpr_(n, os);
    return os.str();
  }
  // expression
  void VisitExpr_(const VarNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const IntImmNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op,
                  std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const StringImmNode *op,
                  std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const CastNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const AddNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const SubNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const MulNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const DivNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const ModNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const MinNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const MaxNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const EQNode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const NENode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const LTNode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const LENode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const GTNode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const GENode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const AndNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const OrNode *op, std::ostream &os) override;   // NOLINT(*)
  void VisitExpr_(const NotNode *op, std::ostream &os) override;  // NOLINT(*)
  void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const RampNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const BufferLoadNode *op,
                  std::ostream &os) override; // NOLINT(*)
  // statement
  void VisitStmt_(const BufferStoreNode *op) override;
  void VisitStmt_(const DeclBufferNode *op) override;
  void VisitStmt_(const LetStmtNode *op) override;
  void VisitStmt_(const AllocateNode *op) override;
  void VisitStmt_(const AttrStmtNode *op) override;
  void VisitStmt_(const ForNode *op) override;
  void VisitStmt_(const WhileNode *op) override;
  void VisitStmt_(const IfThenElseNode *op) override;
  void VisitStmt_(const SeqStmtNode *op) override;
  void VisitStmt_(const EvaluateNode *op) override;
  void VisitStmt_(const AssertStmtNode *op) override;
protected:
  // Get a string of type casting
  virtual std::string CastFromTo_(const std::string &value, DataType from,
                                  DataType target);
  virtual void PrintBinaryExpr_(const std::string &opstr, DataType dtype,
                                PrimExpr lhs, PrimExpr rhs,
                                std::ostream &os); // NOLINT(*)
  virtual void PrintBinaryIntrinsic_(const CallNode *op, const char *opstr,
                                     std::ostream &os); // NOLINT(*)
  /*!
   * \brief Print external function call.
   * \param ret_type The return type.
   * \param global_symbol The symbol of the target function.
   * \param args The arguments to the function.
   * \param skip_first_arg Whether to skip the first argument.
   * \param os The output stream.
   */
  virtual void PrintCallExtern_(Type ret_type, ffi::String global_symbol,
                                const ffi::Array<PrimExpr> &args,
                                bool skip_first_arg,
                                std::ostream &os); // NOLINT(*)
  // Print reference to a buffer as type t in index.
  virtual std::string GetBufferRef_(DataType t, const BufferNode *buffer,
                                    PrimExpr index);
  /*!
   * \brief Register the data type of buf_var
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
  void RegisterHandleType_(const VarNode *buf_var, DataType t);
  /*!
   * \brief If buffer is allocated as type t.
   * \param buf_var The buffer variable.
   * \param t The type to be checked.
   */
  bool HandleTypeMatch_(const VarNode *buf_var, DataType t) const;
protected:
  /*! \brief the storage scope of allocation */
  std::unordered_map<const VarNode *, std::string> alloc_storage_scope_;
  /*! \brief Record of ops that have pre-defined global symbol. */
  OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ =
      Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
  // cache commonly used ops
  const Op &builtin_call_extern_ = builtin::call_extern();
  const Op &builtin_call_pure_extern_ = builtin::call_pure_extern();
private:
  /*! \brief the data type of allocated buffers */
  std::unordered_map<const VarNode *, DataType> handle_data_type_;
  /* \brief Map of GlobalVar to their symbol.
   *
   * For externally-exposed functions, this is given by the
   * tvm::attr::kGlobalSymbol attribute of the PrimFunc. For internal
   * functions, this is the name of the function's GlobalVar, possibly
   * altered to prevent duplicate names.
   */
  std::unordered_map<GlobalVar, ffi::String> internal_functions_;
  /* \brief Name supply to generate unique function names */
  NameSupply func_name_supply_;
  /*!
   * \brief Escape a string to be a valid Python double-quoted string literal.
   * \param s The input string to escape.
   * \param os The output stream to write the escaped string to.
   */
  void EscapeStringLiteral_(const std::string &s, std::ostream &os);
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_PY_H_
/*!
* \file target/codegen_utils.cc
* \brief Shared utility functions for code generation
*/
#include "codegen_utils.h"
namespace tvm {
namespace codegen {
// Return true iff `s` is wrapped by one matching pair of parentheses, i.e.
// the opening '(' at position 0 closes exactly at the final character.
bool CheckOutermostParenthesesMatch(const std::string &s) {
  const size_t len = s.size();
  if (len < 2 || s.front() != '(' || s.back() != ')') {
    return false;
  }
  int depth = 0;
  for (size_t i = 0; i < len; ++i) {
    if (s[i] == '(') {
      ++depth;
    } else if (s[i] == ')') {
      --depth;
    }
    if (depth < 0) {
      return false; // a ')' appeared before its '('
    }
    if (depth == 0) {
      // The outermost '(' just closed: it must close on the last character.
      return i + 1 == len;
    }
  }
  return false; // unbalanced: the opening '(' never closed
}
// Strip one layer of parentheses when they wrap the entire string; otherwise
// return the input unchanged.  The wrap check is inlined: the string wraps
// iff it starts with '(' , ends with ')', and the running depth stays
// positive everywhere before the final character.
std::string RemoveOutermostParentheses(const std::string &s) {
  if (s.size() >= 2 && s.front() == '(' && s.back() == ')') {
    int depth = 0;
    bool wraps = true;
    for (size_t i = 0; i + 1 < s.size(); ++i) {
      if (s[i] == '(') {
        ++depth;
      } else if (s[i] == ')') {
        --depth;
      }
      if (depth <= 0) {
        wraps = false; // outermost '(' closed (or mismatched) too early
        break;
      }
    }
    if (wraps) {
      return s.substr(1, s.size() - 2);
    }
  }
  return s;
}
} // namespace codegen
} // namespace tvm
/*!
* \file target/codegen_utils.h
* \brief Shared utility functions for code generation
*/
#ifndef TVM_TARGET_CODEGEN_UTILS_H_
#define TVM_TARGET_CODEGEN_UTILS_H_
#include <string>
namespace tvm {
namespace codegen {
/*!
 * \brief Check whether the whole string is wrapped by one matching pair of
 * parentheses.
 * \param s The input string
 * \return true if the first character is '(' and the last character is ')'
 * and they form a matching pair
 */
bool CheckOutermostParenthesesMatch(const std::string &s);
/*!
 * \brief Remove the outermost parentheses when they wrap the entire string
 * \param s The input string
 * \return The string with outermost parentheses removed if they match,
 * otherwise return the original string
 */
std::string RemoveOutermostParentheses(const std::string &s);
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_UTILS_H_
......@@ -2,6 +2,7 @@
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
namespace tvm {
namespace codegen {
......@@ -24,7 +25,11 @@ ExtractFuncInfo(const IRModule &mod) {
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
DataType dtype = f->params[i].dtype();
// Device runtime cannot directly take bool arguments, map to int32.
if (dtype.is_bool())
dtype = DataType::Int(32);
info.arg_types.push_back(dtype);
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
......@@ -62,7 +67,10 @@ ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
std::string ptx;
if (const auto f =
ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
ptx = (*f)(code, target).cast<std::string>();
// Fetch current pass context config and pass into the compile callback
tvm::transform::PassContext pass_ctx =
tvm::transform::PassContext::Current();
ptx = (*f)(code, target, pass_ctx->config).cast<std::string>();
if (ptx[0] != '/')
fmt = "cubin";
} else {
......
#include "codegen_cutedsl.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
// Lower every device kernel in `mod` to CuTe DSL source, run the optional
// user-registered post-processing callback, and package the (uncompiled)
// source text into a CUDA module alongside the extracted function info.
ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) {
  CodeGenTileLangCuTeDSL cg;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCuTeDSL: Can only take PrimFunc";
    auto gvar = Downcast<GlobalVar>(kv.first);
    auto prim_func = Downcast<PrimFunc>(kv.second);
    // Only device-kernel-launch functions are expected at this stage.
    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
    cg.AddFunction(gvar, prim_func);
  }
  std::string code = cg.Finish();
  const auto postproc =
      ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc");
  if (postproc) {
    code = (*postproc)(code, target).cast<std::string>();
  }
  return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}
// Register the CuTe DSL builder under the global FFI name so Python can
// invoke it as "target.build.tilelang_cutedsl_without_compile".
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile",
                        BuildTileLangCuTeDSLWithoutCompile);
}
} // namespace codegen
} // namespace tvm
......@@ -35,7 +35,11 @@ ExtractFuncInfo(const IRModule &mod) {
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
DataType dtype = f->params[i].dtype();
// Device runtime cannot directly take bool arguments, map to int32.
if (dtype.is_bool())
dtype = DataType::Int(32);
info.arg_types.push_back(dtype);
}
if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
tir::attr::kKernelLaunchParams)) {
......
......@@ -12,7 +12,11 @@ using cutlass::bfloat16_t;
using cutlass::half_t;
#define TL_DEVICE __forceinline__ __device__
#define TL_NOT_IMPLEMENTED() \
{ \
printf("%s not implemented\n", __PRETTY_FUNCTION__); \
asm volatile("brkpt;\n"); \
}
template <typename T> struct normalize_atomic_type {
using type = T;
};
......@@ -42,98 +46,284 @@ template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
#endif
template <typename T1, typename T2>
TL_DEVICE void AtomicMax(T1 &ref, T2 val,
TL_DEVICE void AtomicMax(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
// There is no implementation of atomicMax for half and bf16 in cuda.
// We simulate this process by atomicCAS loop.
unsigned short *address_as_ushort =
reinterpret_cast<unsigned short *>(address);
unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
unsigned short old_val_ushort = *address_as_ushort;
while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
unsigned short assumed_val_ushort = old_val_ushort;
old_val_ushort =
atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
if (assumed_val_ushort == old_val_ushort) {
break;
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}
template <typename T1, typename T2>
TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
TL_DEVICE T1 AtomicMaxRet(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
return static_cast<T1>(
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
unsigned short *address_as_ushort =
reinterpret_cast<unsigned short *>(address);
unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
unsigned short old_val_ushort = *address_as_ushort;
while (val > *reinterpret_cast<T1 *>(&old_val_ushort)) {
unsigned short assumed_val_ushort = old_val_ushort;
old_val_ushort =
atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
if (assumed_val_ushort == old_val_ushort) {
break;
}
}
return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicMin(T1 &ref, T2 val,
TL_DEVICE void AtomicMin(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
// There is no implementation of atomicMin for half and bf16 in cuda.
// We simulate this process by atomicCAS loop.
unsigned short *address_as_ushort =
reinterpret_cast<unsigned short *>(address);
unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
unsigned short old_val_ushort = *address_as_ushort;
while (val < *reinterpret_cast<T1 *>(&old_val_ushort)) {
unsigned short assumed_val_ushort = old_val_ushort;
old_val_ushort =
atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
if (assumed_val_ushort == old_val_ushort) {
break;
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}
template <typename T1, typename T2>
TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
TL_DEVICE T1 AtomicMinRet(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
return static_cast<T1>(
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
unsigned short *address_as_ushort =
reinterpret_cast<unsigned short *>(address);
unsigned short val_as_ushort = *reinterpret_cast<unsigned short *>(&val);
unsigned short old_val_ushort = *address_as_ushort;
while (val < *reinterpret_cast<T1 *>(&old_val_ushort)) {
unsigned short assumed_val_ushort = old_val_ushort;
old_val_ushort =
atomicCAS(address_as_ushort, assumed_val_ushort, val_as_ushort);
if (assumed_val_ushort == old_val_ushort) {
break;
}
}
return static_cast<T1>(*reinterpret_cast<T1 *>(&old_val_ushort));
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 890))
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
TL_DEVICE void AtomicAdd(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
if (memory_order == int(cuda::memory_order_relaxed)) {
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
// Since atomic ref do not support memory order, we need to inline ptx
// code here for each situation
if constexpr (std::is_same_v<NT1, half>) {
// fp16
__half ret_val;
unsigned short ret_val_cast =
*reinterpret_cast<unsigned short *>(&ret_val);
unsigned long long ref_address =
reinterpret_cast<unsigned long long>(address);
unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
if (memory_order == int(cuda::memory_order_release) ||
memory_order == int(cuda::memory_order_consume)) {
asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acquire)) {
asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acq_rel) ||
memory_order == int(cuda::memory_order_seq_cst)) {
asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
}
} else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
// bf16
__nv_bfloat16 ret_val;
unsigned short ret_val_cast =
*reinterpret_cast<unsigned short *>(&ret_val);
unsigned long long ref_address =
reinterpret_cast<unsigned long long>(address);
unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
if (memory_order == int(cuda::memory_order_release) ||
memory_order == int(cuda::memory_order_consume)) {
asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acquire)) {
asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acq_rel) ||
memory_order == int(cuda::memory_order_seq_cst)) {
asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
}
}
}
} else {
atomicAdd(reinterpret_cast<NT1 *>(address), cuda_cast<NT1>(val));
}
}
#else
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
(void)memory_order;
atomicAdd(reinterpret_cast<NT1 *>(address), cuda_cast<NT1>(val));
}
#endif
template <typename T1, typename T2>
TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
T1 *address = &ref;
if constexpr ((std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) &&
memory_order == int(cuda::memory_order_relaxed)) {
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
if (memory_order == int(cuda::memory_order_relaxed)) {
return static_cast<T1>(
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
} else {
if constexpr (std::is_same_v<NT1, half>) {
// fp16
__half ret_val;
unsigned short ret_val_cast =
*reinterpret_cast<unsigned short *>(&ret_val);
unsigned long long ref_address =
reinterpret_cast<unsigned long long>(address);
unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
if (memory_order == int(cuda::memory_order_release) ||
memory_order == int(cuda::memory_order_consume)) {
asm volatile("atom.release.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acquire)) {
asm volatile("atom.acquire.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acq_rel) ||
memory_order == int(cuda::memory_order_seq_cst)) {
asm volatile("atom.acq_rel.gpu.global.add.noftz.f16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
}
return static_cast<T1>(*reinterpret_cast<__half *>(&ret_val_cast));
} else if constexpr (std::is_same_v<NT1, __nv_bfloat16>) {
// bf16
__nv_bfloat16 ret_val;
unsigned short ret_val_cast =
*reinterpret_cast<unsigned short *>(&ret_val);
unsigned long long ref_address =
reinterpret_cast<unsigned long long>(address);
unsigned short val_cast = *reinterpret_cast<unsigned short *>(&val);
if (memory_order == int(cuda::memory_order_release) ||
memory_order == int(cuda::memory_order_consume)) {
asm volatile("atom.release.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acquire)) {
asm volatile("atom.acquire.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
} else if (memory_order == int(cuda::memory_order_acq_rel) ||
memory_order == int(cuda::memory_order_seq_cst)) {
asm volatile("atom.acq_rel.gpu.global.add.noftz.bf16 %0, [%1], %2;"
: "=h"(ret_val_cast)
: "l"(ref_address), "h"(val_cast)
: "memory");
}
return static_cast<T1>(
*reinterpret_cast<__nv_bfloat16 *>(&ret_val_cast));
}
}
} else {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
return static_cast<T1>(
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order)));
#else
TL_NOT_IMPLEMENTED();
#endif
}
}
......@@ -456,16 +646,66 @@ AtomicAddx4Ret(float *ref, float *val,
return ret_val;
}
}
#else
TL_DEVICE void AtomicAddx2(float *ref, float *val,
int memory_order = int(cuda::memory_order_relaxed)) {
(void)memory_order;
float2 add_val = *reinterpret_cast<float2 *>(val);
atomicAdd(ref + 0, add_val.x);
atomicAdd(ref + 1, add_val.y);
}
TL_DEVICE float2
AtomicAddx2Ret(float *ref, float *val,
int memory_order = int(cuda::memory_order_relaxed)) {
(void)memory_order;
float2 add_val = *reinterpret_cast<float2 *>(val);
float2 ret;
ret.x = atomicAdd(ref + 0, add_val.x);
ret.y = atomicAdd(ref + 1, add_val.y);
return ret;
}
TL_DEVICE void AtomicAddx4(float *ref, float *val,
int memory_order = int(cuda::memory_order_relaxed)) {
(void)memory_order;
float4 add_val = *reinterpret_cast<float4 *>(val);
atomicAdd(ref + 0, add_val.x);
atomicAdd(ref + 1, add_val.y);
atomicAdd(ref + 2, add_val.z);
atomicAdd(ref + 3, add_val.w);
}
TL_DEVICE float4
AtomicAddx4Ret(float *ref, float *val,
int memory_order = int(cuda::memory_order_relaxed)) {
(void)memory_order;
float4 add_val = *reinterpret_cast<float4 *>(val);
float4 ret;
ret.x = atomicAdd(ref + 0, add_val.x);
ret.y = atomicAdd(ref + 1, add_val.y);
ret.z = atomicAdd(ref + 2, add_val.z);
ret.w = atomicAdd(ref + 3, add_val.w);
return ret;
}
#endif
template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
template <typename T> TL_DEVICE T AtomicLoad(T *ref, int memory_order) {
#if CUDART_VERSION >= 11080
cuda::atomic_ref<T, cuda::thread_scope_device> aref(*ref);
return aref.load(cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
TL_DEVICE void AtomicStore(T1 *ref, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
#if CUDART_VERSION >= 11080
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*ref);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
#endif
}
......@@ -127,6 +127,16 @@ TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
return result;
}
TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0,
short z1, short w0, short w1) {
int4_t result;
*((short2 *)&result.x) = make_short2(x0, x1);
*((short2 *)&result.y) = make_short2(y0, y1);
*((short2 *)&result.z) = make_short2(z0, z1);
*((short2 *)&result.w) = make_short2(w0, w1);
return result;
}
// Pack eight int values.
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
int z1, int w0, int w1) {
......
......@@ -26,7 +26,8 @@ template <int N> TL_DEVICE void cp_async_wait() {
}
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
TL_DEVICE void cp_async_gs(void const *const smem_addr,
void const *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
......@@ -37,7 +38,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
"l"((void const *)(global_ptr)), "n"(N));
} else {
asm volatile(
#if TL_ENABLE_L2_PREFETCH
......@@ -46,13 +47,13 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
"cp.async.ca.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
"l"((void const *)(global_ptr)), "n"(N));
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
void *global_ptr, bool cond) {
void const *global_ptr, bool cond) {
static_assert(N == 16 || N == 8 || N == 4);
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
......@@ -64,7 +65,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
"l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
} else {
asm volatile(
#if TL_ENABLE_L2_PREFETCH
......@@ -73,7 +74,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
"l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
}
}
......
......@@ -5,6 +5,7 @@
namespace tl {
// 256-bit load for longlong4
__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
longlong4 ret;
asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
......@@ -13,13 +14,18 @@ __device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
return ret;
}
__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) {
asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
// 256-bit load for ulonglong4
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
// Generic 256-bit load for FP8 types (returns ulonglong4)
template <typename T>
__device__ __forceinline__ ulonglong4 ld_global_256(const T *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
......@@ -27,6 +33,22 @@ __device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
return ret;
}
// 256-bit store for longlong4
__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) {
asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// 256-bit store for ulonglong4 with non-const reference
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
ulonglong4 &val) {
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// 256-bit store for ulonglong4 with const reference
// must be const &val, otherwise the compiler will generate a temporary variable
// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr))
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
......@@ -36,20 +58,22 @@ __device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
// Generic 256-bit store for FP8 types
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, const ulonglong4 &val) {
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
fp8_e4_32_t &val8) {
ulonglong4 &val = *((ulonglong4 *)&val8);
// Generic 256-bit store for FP8 types with non-const reference
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, T &val) {
ulonglong4 &val_u64 = *((ulonglong4 *)&val);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
: "l"(ptr), "l"(val_u64.x), "l"(val_u64.y), "l"(val_u64.z),
"l"(val_u64.w));
}
__device__ __forceinline__ unsigned long long
......@@ -95,38 +119,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
}
}
template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tcgen05_ld_core<tl::tmem_ld_32dp32bNx<pack16>, 7, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tcgen05_ld_core<tl::tmem_ld_32dp64bNx<pack16>, 7, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
tcgen05_ld_core<tl::tmem_ld_32dp128bNx<pack16>, 6, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
tcgen05_ld_core<tl::tmem_ld_32dp256bNx<pack16>, 5, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
......
......@@ -15,14 +15,14 @@ enum class CacheHintSm90 : uint64_t {
};
template <typename BarrierType = uint64_t>
TL_DEVICE void tma_load(void *smem_ptr, void *gmem_ptr, BarrierType &smem_mbar,
uint32_t size) {
TL_DEVICE void tma_load(void *smem_ptr, void const *gmem_ptr,
BarrierType &smem_mbar, uint32_t size) {
uint32_t smem_int_mbar =
smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::"
"bytes [%0], [%1], %2, [%3]; \n" ::"r"(smem_int_ptr),
"l"(gmem_ptr), "r"(size), "r"(smem_int_mbar)
"l"((void const *)gmem_ptr), "r"(size), "r"(smem_int_mbar)
:);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment