Commit ece99756 authored by LeiWang1999

Merge branch 'main' of https://github.com/microsoft/TileLang into main

parents ee973b49 913d14f2
......@@ -99,6 +99,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
src/transform/*.cc
src/op/*.cc
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
)
# Include CUDA source files if CUDA is enabled
......
......@@ -7,13 +7,33 @@
*
*/
#include <tvm/arith/analyzer.h>
#include <tvm/script/ir_builder/tir/ir.h>
namespace tvm {
namespace tl {
constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";
using namespace script::ir_builder::tir;
static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
using namespace tvm::tir;
Var var = Var(name);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
};
return ForFrame(n);
}
ForFrame ParallelFor(Array<PrimExpr> extents,
Map<String, ObjectRef> annotations) {
using namespace tvm::tir;
......@@ -121,6 +141,23 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
Array<PrimExpr> block_size,
Map<String, ObjectRef> attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);
if (is_cpu_kernel_frame) {
ICHECK(grid_size.size() >= 0);
ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size";
ICHECK(attrs.defined());
// create grid loop var
for (int i = 0; i < grid_size.size(); i++) {
n->frames.push_back(
MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
}
// Launch CPU Kernel
} else {
// Launch GPU Kernel
ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0)
n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
......@@ -139,6 +176,8 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
} else {
n->frames.push_back(Block(""));
}
}
if (attrs.defined()) {
auto empty_block = Block("");
empty_block->annotations = attrs;
......
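A plain-Python behavioral sketch of the dispatch above (hypothetical helper, not part of the commit): when the "tilelang.is_cpu_kernel_frame" attribute is set, each grid extent becomes a serial loop instead of a blockIdx binding.

def launch(grid_size, block_size, is_cpu_kernel_frame, body):
    # Sketch only: models KernelLaunch's CPU/GPU split in plain Python.
    if is_cpu_kernel_frame:
        assert len(block_size) == 0, "CPU kernel cannot have block size"

        def run(dims=()):
            # Nest one serial loop per grid dimension (block_var_0, block_var_1, ...).
            if len(dims) == len(grid_size):
                body(*dims)
                return
            for i in range(grid_size[len(dims)]):
                run(dims + (i,))

        run()
    else:
        # GPU path: up to three blockIdx bindings plus threadIdx bindings.
        assert len(grid_size) <= 3
        raise NotImplementedError("bind blockIdx/threadIdx frames as before")

For example, launch((2, 2), (), True, print) visits the four (bx, by) grid points serially.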
......@@ -138,6 +138,8 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
}
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined())
return ldsm_stmt;
......@@ -148,12 +150,19 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(fused_loop);
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(fused_loop);
} else {
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
vectorized_thread_loop = VectorizeLoop(thread_loop);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_cpp.cc
*/
#include "codegen_cpp.h"
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/crt/error_codes.h>
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "support/str_escape.h"
#include "target/build_common.h"
#include "target/func_registry_generator.h"
#include "target/source/codegen_params.h"
namespace tvm {
namespace codegen {
CodeGenTileLangCPP::CodeGenTileLangCPP() {
module_name_ = name_supply_->FreshName("__tvm_module_ctx");
}
void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts,
bool emit_fwd_func_decl, std::string target_str,
const std::unordered_set<std::string> &devices) {
emit_asserts_ = emit_asserts;
emit_fwd_func_decl_ = emit_fwd_func_decl;
declared_globals_.clear();
decl_stream << "// tilelang target: " << target_str << "\n";
decl_stream << "#include <tl_templates/cpp/common.h>\n";
decl_stream << "#include <tl_templates/cpp/gemm.h>\n";
decl_stream << "\n";
CodeGenC::Init(output_ssa);
}
void CodeGenTileLangCPP::InitGlobalContext() {
decl_stream << "void* " << tvm::runtime::symbol::tvm_module_ctx
<< " = NULL;\n";
}
void CodeGenTileLangCPP::DefineModuleName() {
decl_stream << "void* " << module_name_ << " = NULL;\n";
}
void CodeGenTileLangCPP::GenerateForwardFunctionDeclarations(
String global_symbol,
const Array<Type> &arg_types, const Type &ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
for (auto &func_already_defined : GetFunctionNames()) {
if (global_symbol == func_already_defined) {
return;
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
void CodeGenTileLangCPP::PrintFuncPrefix(std::ostream &os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n";
}
void CodeGenTileLangCPP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK_EQ(lanes, 1) << "does not support vector types";
os << "void*";
return;
}
if (t.is_void()) {
os << "void";
return;
}
if (t == DataType::Bool()) {
os << "bool";
return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16:
os << "half";
break;
case 32:
os << "float";
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
switch (t.bits()) {
case 8:
os << "int8_t";
break;
case 16:
os << "int16_t";
break;
case 32:
os << "int32_t";
break;
case 64:
os << "int64_t";
break;
case 1:
os << "int32_t";
break;
default:
fail = true;
break;
}
if (!fail && lanes == 1)
return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes;
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to C type";
}
void CodeGenTileLangCPP::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
os << "((";
PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << v;
}
os << "))";
}
void CodeGenTileLangCPP::PrintGetFuncFromBackend(
const std::string &func_name, const std::string &packed_func_name) {
this->PrintIndent();
this->stream << "if (" << packed_func_name << " == NULL) {\n";
int packed_func_if_scope = this->BeginScope();
this->PrintIndent();
this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \""
<< func_name << "\""
<< ", &" << packed_func_name << ") != 0) {\n";
int get_func_env_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(get_func_env_scope);
this->PrintIndent();
this->stream << "}\n";
this->EndScope(packed_func_if_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenTileLangCPP::PrintFuncCall(const std::string &packed_func_name,
int num_args) {
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (TVMFuncCall(" << packed_func_name << ", "
<< "(TVMValue*) stack_value"
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
<< "&" << ret_val << ", "
<< "&" << ret_type_code << ") != 0) {\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenTileLangCPP::PrintFuncCallC(
const std::string &packed_func_name, int num_args,
const std::string &resource_handle_name) {
this->PrintIndent();
std::string ret_val = name_supply_->FreshName("ret_val");
std::string ret_type_code = name_supply_->FreshName("ret_type_code");
this->stream << "TVMValue " << ret_val << ";\n";
this->PrintIndent();
this->stream << "int " << ret_type_code << ";\n";
this->PrintIndent();
this->stream << "if (" << packed_func_name << "( "
<< "(TVMValue*) stack_value "
<< ", "
<< "(int*) stack_tcode"
<< ", " << num_args << ", "
<< "&" << ret_val << ", "
<< "&" << ret_type_code << ", " << resource_handle_name
<< ") != 0){\n";
int func_call_scope = this->BeginScope();
this->PrintIndent();
this->stream << "return -1;\n";
this->EndScope(func_call_scope);
this->PrintIndent();
this->stream << "}\n";
}
void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f, stream);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0)
stream << ", ";
if (v.dtype().is_handle()) {
// Workaround for grid_constant parameters.
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream);
stream << ' ' << vid;
continue;
}
}
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
CodeGenC::PrintType(GetType(v), stream);
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, stream);
}
} else {
CodeGenC::PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
std::string CodeGenTileLangCPP::GetPackedName(const CallNode *op) {
const StringImmNode *s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr)
<< "tvm_call_packed_lowered expects first argument as function name";
std::string func_name = s->value;
std::string packed_func_name = func_name + "_packed";
std::string unique_name;
auto it = declared_globals_.find(packed_func_name);
if (it != declared_globals_.end()) {
unique_name = it->second;
} else {
unique_name = name_supply_->FreshName(packed_func_name);
declared_globals_[packed_func_name] = unique_name;
decl_stream << "static void* " << unique_name << " = NULL;\n";
}
return unique_name;
}
CodeGenTileLangCPP::FunctionInfo
CodeGenTileLangCPP::GetFunctionInfo(const CallNode *op,
bool has_resource_handle) {
const StringImmNode *s = op->args[0].as<StringImmNode>();
ICHECK(s != nullptr)
<< "tvm_call_[c]packed_lowered expects first argument as function name";
int64_t begin = op->args[3].as<IntImmNode>()->value;
int64_t end = op->args[4].as<IntImmNode>()->value;
int64_t num_args = end - begin;
ICHECK_GE(num_args, 0);
std::string func_name = s->value;
if (has_resource_handle) {
const StringImmNode *resource_handle_var = op->args[5].as<StringImmNode>();
if (resource_handle_var != nullptr) {
std::string resource_handle_name = resource_handle_var->value;
return {func_name, num_args - 1, resource_handle_name};
} else {
// The final arg should be "(void*) NULL" to indicate the empty
// resource_handle.
num_args--;
const CallNode *reinterpret_call = op->args[5].as<CallNode>();
ICHECK_NE(reinterpret_call, (void *)nullptr)
<< "At CallNode to " << s
<< "arg 5: Expect either StringImm naming the resource_handle var "
"from interface API or "
<< "reinterpret(0); got: " << op->args[5];
ICHECK_EQ(reinterpret_call->op, builtin::reinterpret())
<< "At CallNode to " << s
<< "arg 5: Expect either StringImm naming the resource_handle var "
"from interface API or "
<< "reinterpret(0); got: " << op->args[5];
ICHECK(is_zero(reinterpret_call->args[0]))
<< "At CallNode to " << s
<< " arg 5: Expect either StringImm naming the "
"resource_handle var from interface API, or "
<< "zero; got " << op->args[5];
}
}
return {func_name, num_args, "NULL"};
}
void CodeGenTileLangCPP::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
if (op->op.same_as(builtin::tvm_stack_alloca())) {
std::string stack_name = name_supply_->FreshName("stack");
const std::string &type = op->args[0].as<StringImmNode>()->value;
const IntImmNode *num = op->args[1].as<IntImmNode>();
ICHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMValue);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
} else if (type == "arg_value") {
size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
} else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
}
this->PrintIndent();
this->stream << "TVMValue " << stack_name << "[" << size << "];\n";
os << stack_name;
} else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
auto function_info = GetFunctionInfo(op, false /* has_resource_handle */);
std::string func_name_packed = GetPackedName(op);
this->PrintGetFuncFromBackend(function_info.func_name, func_name_packed);
this->PrintFuncCall(func_name_packed, function_info.num_args);
} else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
auto function_info = GetFunctionInfo(op, true /* has_resource_handle */);
this->PrintFuncCallC(function_info.func_name, function_info.num_args,
function_info.resource_handle_name);
} else if (op->op.same_as(builtin::tvm_throw_last_error())) {
this->PrintIndent();
this->stream << "return -1;\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangCPP::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*)
if (emit_asserts_) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
PrintIndent();
stream << "TVMAPISetLastError(\"" << op->message.as<StringImmNode>()->value
<< "\");\n";
PrintIndent();
stream << "return -1;\n";
this->EndScope(assert_if_scope);
PrintIndent();
stream << "}\n";
}
this->PrintStmt(op->body);
}
void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode *buffer = op->buffer_var.as<VarNode>();
PrintType(op->dtype, stream);
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
stream << ' ' << vid << '[' << constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
void CodeGenTileLangCPP::VisitExpr_(const MinNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, "<", os);
}
void CodeGenTileLangCPP::VisitExpr_(const MaxNode *op,
std::ostream &os) { // NOLINT(*)
PrintTernaryCondExpr(op, ">", os);
}
template <typename T>
inline void
CodeGenTileLangCPP::PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os) { // NOLINT(*)
std::ostringstream temp_a;
VisitExpr(op->a, temp_a);
std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
std::ostringstream temp_b;
VisitExpr(op->b, temp_b);
std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());
os << "((" << a_id << ") " << compare << " (" << b_id << ") "
<< "? (" << a_id << ") : (" << b_id << "))";
}
} // namespace codegen
} // namespace tvm
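For quick reference, a summary of the scalar C types that PrintType above emits (a sketch distilled from the code, with a hypothetical dict name; vector types additionally append the lane count, e.g. a 4-lane float32 becomes float4).

# Scalar lowering summary for CodeGenTileLangCPP::PrintType (reference sketch).
CPP_SCALAR_TYPES = {
    "float16": "half",
    "float32": "float",
    "float64": "double",
    "bool": "bool",
    "int8": "int8_t",
    "int16": "int16_t",
    "int32": "int32_t",
    "int64": "int64_t",
    "uint8": "uint8_t",
    "uint32": "uint32_t",
    "handle": "void*",
}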
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_cpp.h
* \brief Generate C host code.
*/
#ifndef TILELANG_TARGET_CODEGEN_CPP_H_
#define TILELANG_TARGET_CODEGEN_CPP_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"
namespace tvm {
namespace codegen {
class CodeGenTileLangCPP : public CodeGenC {
public:
CodeGenTileLangCPP();
void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
std::string target_str,
const std::unordered_set<std::string> &devices);
void InitGlobalContext();
// Override this as a workaround for non-TVM-runtime code generation
void AddFunction(const PrimFunc &f);
/*!
* \brief Add functions from the (unordered) range to the current module in a
* deterministic order. This helps with debugging.
*
* \param functions A vector of unordered range of current module.
*/
void AddFunctionsOrdered(
std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();
using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream &os) final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) override; // NOLINT(*)
// overload min and max to use the ternary operator, so we don't rely on the
// standard library implementations
void VisitExpr_(const MinNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const MaxNode *op, std::ostream &os) final; // NOLINT(*)
void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
void VisitStmt_(const AllocateNode *op) final; // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<Type> &arg_types,
const Type &ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }
private:
/* \brief Internal structure to store information about function calls */
struct FunctionInfo {
/* \brief function name */
std::string func_name;
/* number of arguments required by the function */
int64_t num_args;
/* \brief name of resource_handle to pass */
std::string resource_handle_name;
};
std::string module_name_;
/* \brief mapping global packed func to the unique name */
std::unordered_map<std::string, std::string> declared_globals_;
/* \brief names of the functions declared in this module */
Array<String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forward function declarations in the resulting C
* code */
bool emit_fwd_func_decl_;
FunctionInfo GetFunctionInfo(const CallNode *op, bool has_resource_handle);
std::string GetPackedName(const CallNode *op);
void PrintGetFuncFromBackend(const std::string &func_name,
const std::string &packed_func_name);
void PrintFuncCall(const std::string &packed_func_name, int num_args);
void PrintFuncCallC(const std::string &packed_func_name, int num_args,
const std::string &resource_handle_name);
/*!
* \brief Print ternary conditional operator implementing binary `op`
* Forces the operands to be in SSA form.
* \param op binary operator being expressed
* \param compare string representation of comparison operator
* \param os stream reference to print into
*/
template <typename T>
inline void PrintTernaryCondExpr(const T *op, const char *compare,
std::ostream &os); // NOLINT(*)
};
} // namespace codegen
} // namespace tvm
#endif // TILELANG_TARGET_CODEGEN_CPP_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "codegen_cpp.h"
namespace tvm {
namespace codegen {
runtime::Module BuildCPPHost(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
bool emit_asserts = false;
bool emit_fwd_func_decl = true;
std::unordered_set<std::string> devices;
if (mod->GetAttr<Map<GlobalVar, String>>("device_contexts") != nullptr) {
Map<GlobalVar, String> device_contexts =
mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value();
for (auto const &context : device_contexts) {
devices.insert(context.second.data());
}
}
CodeGenTileLangCPP cg;
cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
cg.SetConstantsByteAlignment(
target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
auto is_aot_executor_fn = [](const PrimFunc &func) -> bool {
return func->GetAttr<Bool>("runner_function", Bool(false)).value();
};
std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<PrimFuncNode>())
<< "CodegenCHost: Can only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
funcs.push_back({gvar, prim_func});
}
// Sort functions
auto sort_key = [&is_aot_executor_fn](const auto &kv) {
return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
};
std::sort(funcs.begin(), funcs.end(),
[&sort_key](const auto &kv_a, const auto &kv_b) {
return sort_key(kv_a) < sort_key(kv_b);
});
// Declare all functions first. This ensures that all functions,
// including the __tvm_main__ used in AOT, have access to forward
// declarations of other functions in the IRModule.
for (const auto &[gvar, prim_func] : funcs) {
cg.DeclareFunction(gvar, prim_func);
}
// Codegen all functions. Passing emit_fwd_func_decl=true adds a
// forward declaration for any `builtin::call_extern`, based on the
// arguments provided to it.
for (const auto &[gvar, prim_func] : funcs) {
cg.AddFunction(prim_func);
}
if (target->GetAttr<Bool>("system-lib").value_or(Bool(false))) {
ICHECK_EQ(target->GetAttr<String>("runtime").value_or(""), "c")
<< "c target only supports generating C runtime SystemLibs";
}
std::string code = cg.Finish();
return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cpp").set_body_typed(BuildCPPHost);
} // namespace codegen
} // namespace tvm
#include <math.h>
#include <stdbool.h>
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file annotate_device_regions.cc
* \brief Split device function from host.
*/
#include "tir/transforms/ir_utils.h"
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tl {
using namespace tir;
class DeviceRegionAnnotater : public StmtMutator {
public:
explicit DeviceRegionAnnotater(Target device_target)
: device_target_(device_target) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
} else if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::pipeline_exec_scope ||
op->attr_key == tir::attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
return StmtMutator::VisitStmt_(op);
}
}
private:
Target device_target_;
};
tvm::transform::Pass AnnotateDeviceRegions() {
using namespace tir::transform;
auto pass_func = [](PrimFunc func, IRModule mod,
tvm::transform::PassContext ctx) -> PrimFunc {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value();
Target device_target = target.WithoutHost();
if (target->GetHost()) {
if (device_target->kind->name == "c") {
// Annotate the function with the device target
auto func_body = func->body;
func.CopyOnWrite()->body =
AttrStmt(device_target, tvm::attr::kTarget, 0, func_body);
}
DeviceRegionAnnotater mutator(target.WithoutHost());
func.CopyOnWrite()->body = mutator(func->body);
}
return func;
};
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {});
}
TVM_REGISTER_GLOBAL("tl.transform.AnnotateDeviceRegions")
.set_body_typed(AnnotateDeviceRegions);
} // namespace tl
} // namespace tvm
......@@ -35,10 +35,86 @@
#include "common/loop_fusion_utils.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
static bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
static bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
}
static bool isLocalFragment(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kLocal &&
storage_scope.tag == ".fragment";
}
/*!
* \brief Collect the mapping from each buffer var to its allocation
*/
class AllocateCollector : public StmtExprVisitor {
public:
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
} else if (isLocalFragment(op->buffer_var)) {
local_fragment_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
if (IsDynamicSharedMemory(buffer->data)) {
dyn_shmem_allocs_[buffer->data.get()] = op;
} else if (IsStaticSharedMemory(buffer->data)) {
static_shmem_allocs_[buffer->data.get()] = op;
} else if (isLocalFragment(buffer->data)) {
local_fragment_allocs_[buffer->data.get()] = op;
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateConstNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const SeqStmtNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode *op) final {
StmtExprVisitor::VisitStmt_(op);
}
// The dynamic mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const Object *> dyn_shmem_allocs_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const Object *> static_shmem_allocs_;
// The local fragment mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const Object *> local_fragment_allocs_;
};
using namespace tir;
using arith::IRMutatorWithAnalyzer;
......@@ -50,37 +126,113 @@ struct LayoutInferenceResult {
class BufferUseDefCollector : public StmtExprVisitor {
public:
BufferUseDefCollector() = default;
BufferUseDefCollector(bool skip_thread_partition)
: skip_thread_partition_(skip_thread_partition) {}
LayoutInferenceResult Run() {
// Basic consistency check: infer_list_ and thread_var_vec_ should have the
// same size
ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
<< "Size mismatch: infer_list_ and thread_var_vec_ must match in "
"length.";
// If needed, you can also check that annotated_layout_map_ is not empty, or
// anything else relevant to your setup.
// Copy the annotated layout map to local variable
Map<Buffer, Layout> layout_map = annotated_layout_map_;
int num_infer = infer_list_.size();
// Prepare BFS queue for iterative inference
std::queue<int> q;
std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) {
// Check that each infer_list_ entry is valid
ICHECK(infer_list_[i] != nullptr)
<< "infer_list_[" << i
<< "] is null. The inference object is not allocated properly.";
// Check that each thread_var_vec_ entry is defined
if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
// TODO(lei): This is a hack for cpu backend
if (!thread_var_.defined()) {
// Fake thread var to inference predicate for the buffer
thread_var_ = IterVar(Range::FromMinExtent(PrimExpr(0), PrimExpr(1)),
Var(""), IterVarType::kDataPar);
}
thread_var_vec_[i] = thread_var_;
}
q.push(i);
}
auto run_infer_step = [&](int cur_infer_id, InferLevel level,
bool update_queue) {
// Range check for cur_infer_id
ICHECK_GE(cur_infer_id, 0)
<< "cur_infer_id is negative, which is invalid.";
ICHECK_LT(cur_infer_id, num_infer)
<< "cur_infer_id " << cur_infer_id << " is out of range, must be < "
<< num_infer << ".";
// Make sure we can safely access infer_list_[cur_infer_id] and
// thread_var_vec_[cur_infer_id]
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
<< "] is null inside run_infer_step.";
// Check iter_var->dom and dom->extent
ICHECK(iter_var.defined())
<< "thread_var_vec_[" << cur_infer_id << "] is not defined.";
ICHECK(iter_var->dom.defined())
<< "iter_var->dom is not defined for infer_list_[" << cur_infer_id
<< "].";
ICHECK(iter_var->dom->extent.defined())
<< "iter_var->dom->extent is not defined for infer_list_["
<< cur_infer_id << "].";
const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
ICHECK(extent_ptr != nullptr)
<< "iter_var->dom->extent is not a constant integer, which is "
"required for layout inference.";
// Run InferLayout
auto updates = next->InferLayout(
LayoutInferArgs{target_, static_cast<size_t>(*extent_ptr),
layout_map},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
if (layout_map.count(buffer)) {
// If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer;
<< "Get different layout for " << buffer
<< " in cur_infer_id = " << cur_infer_id;
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
if (!update_queue)
continue;
// Check if buffer exists in use_list_
ICHECK(use_list_.count(buffer))
<< "Buffer " << buffer << " not found in use_list_. "
<< "Potential mismatch between inference updates and use_list_.";
// Push back into BFS queue
for (int idx : use_list_[buffer]) {
ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
<< " is negative.";
ICHECK_LT(idx, num_infer)
<< "Index in use_list_ for buffer " << buffer
<< " out of range: " << idx << " >= " << num_infer << ".";
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
......@@ -89,49 +241,66 @@ public:
}
}
};
auto finish_infer_queue = [&]() {
while (!q.empty()) {
int cur_infer_id = q.front();
q.pop();
// Range check again, just to be safe
ICHECK_GE(cur_infer_id, 0);
ICHECK_LT(cur_infer_id, num_infer);
in_queue[cur_infer_id] = false;
run_infer_step(cur_infer_id, InferLevel::kCommon, true);
}
};
// step 1: infer strict layout
for (int i = 0; i < num_infer; i++) {
run_infer_step(i, InferLevel::kStrict, false);
}
// step 2: infer common layout with BFS
finish_infer_queue();
// step 3: relax constraints to free and re-run
for (int i = 0; i < num_infer; i++) {
run_infer_step(i, InferLevel::kFree, true);
finish_infer_queue();
}
// Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment" && layout_map.count(buffer) == 0)
LOG_ERROR << "The layout for fragment " << buffer
if (buffer.scope() == "local.fragment") {
ICHECK_NE(layout_map.count(buffer), 0)
<< "The layout for fragment " << buffer
<< " can not be inferred correctly.";
}
}
// Collect layout info for For nodes
Map<For, Fragment> for_map;
Map<For, PrimExpr> predicate_map;
for (auto &base_infer : infer_list_) {
// Check if base_infer is valid
ICHECK(base_infer != nullptr) << "Null pointer encountered in "
"infer_list_ while collecting for_map.";
if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
// Check that the loop layout is defined
ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for can not be inferred correctly : \n"
<< "The Layout for Parallel for cannot be inferred correctly:\n"
<< for_infer->GetRoot();
for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
// thread_var_ should be defined if we rely on it
ICHECK(thread_var_.defined())
<< "thread_var_ is not defined. Cannot retrieve predicate.";
if (auto predicate = for_infer->GetPredicate(thread_var_->var)) {
predicate_map.Set(for_infer->GetRoot(), predicate.value());
}
}
}
return {layout_map, for_map, predicate_map};
}
......@@ -231,26 +400,28 @@ private:
std::vector<IterVar> thread_var_vec_;
Target target_;
LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false};
};
class LayoutInferencer : public IRMutatorWithAnalyzer {
public:
static PrimFunc Substitute(PrimFunc f, bool skip_thread_partition = false) {
arith::Analyzer analyzer;
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = ParallelLoopFuser::Fuse(f->body);
BufferUseDefCollector collector(skip_thread_partition);
collector.Collect(f);
auto result = collector.Run();
LayoutInferencer substituter(result, skip_thread_partition, &analyzer);
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
LayoutInferencer(const LayoutInferenceResult result,
bool skip_thread_partition, arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer), result_(result),
skip_thread_partition_(skip_thread_partition){};
Stmt VisitStmt_(const BlockNode *op) final {
Block block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
......@@ -270,8 +441,12 @@ private:
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto loop_layout = result_.for_map[GetRef<For>(op)];
if (!skip_thread_partition_) {
// Thread bindings are available, so partition the loop over the thread variable
for_node =
PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
}
for_node = VectorizeLoop(for_node);
if (result_.predicate_map.count(GetRef<For>(op))) {
return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
......@@ -296,12 +471,22 @@ private:
private:
const LayoutInferenceResult result_;
IterVar thread_var_;
bool skip_thread_partition_{false};
};
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
AllocateCollector collector;
collector(f->body);
// TODO(Lei): This is a hack to avoid thread-partition issues on the
// CPU backend. It should be removed once there is a better way to
// detect whether thread partitioning is needed.
bool need_thread_partition = (collector.dyn_shmem_allocs_.size() > 1 ||
collector.static_shmem_allocs_.size() > 1 ||
collector.local_fragment_allocs_.size() > 1);
bool skip_thread_partition = !need_thread_partition;
return LayoutInferencer::Substitute(std::move(f), skip_thread_partition);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
......
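In plain terms, the heuristic above only keeps thread partitioning when the function allocates more than one dynamic-shared, static-shared, or local.fragment buffer; a plain-Python restatement (hypothetical helper name):

def needs_thread_partition(num_dyn_shmem: int, num_static_shmem: int,
                           num_local_fragment: int) -> bool:
    # Mirrors the check added in the tl.LayoutInference pass above.
    return num_dyn_shmem > 1 or num_static_shmem > 1 or num_local_fragment > 1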
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
#include "tir/transforms/arg_binder.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
static constexpr const char *kDeviceContextVar = "device_api_context";
namespace {
class ReturnRewriter : public StmtMutator {
public:
explicit ReturnRewriter(Var ret_var, Var ret_tcode)
: ret_var_(ret_var), ret_tcode_(ret_tcode) {}
Stmt VisitStmt_(const ForNode *node) override {
if (node->kind == ForKind::kParallel)
in_parallel_ += 1;
Stmt ret = StmtMutator::VisitStmt_(node);
if (node->kind == ForKind::kParallel)
in_parallel_ -= 1;
return ret;
}
Stmt VisitStmt_(const EvaluateNode *node) override {
Stmt ret = StmtMutator::VisitStmt_(node);
const EvaluateNode *eval = ret.as<EvaluateNode>();
ICHECK(eval);
if (const CallNode *call = eval->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
ICHECK_EQ(in_parallel_, 0)
<< "tir.ret cannot be used in parallel scope.";
ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
ret = WriteToOut(call->args[0]);
}
}
return ret;
}
private:
struct ConvertedInfo {
int tcode{-1};
PrimExpr expr;
Buffer dummy_val_buffer;
Buffer dummy_tcode_buffer;
};
ConvertedInfo ConvertForFFI(PrimExpr val) {
ConvertedInfo info;
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
info.tcode = kTVMArgInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
info.tcode = kTVMArgFloat;
info.expr = Cast(DataType::Float(64), val);
} else if (dtype.is_void()) {
info.tcode = kTVMNullptr;
info.expr = val;
} else {
LOG(FATAL) << "data type " << dtype << " not supported yet";
}
// If multiple return locations have the same data type, use the
// same dummy buffer declaration.
auto it = dummy_val_buffer_map_.find(info.tcode);
if (it != dummy_val_buffer_map_.end()) {
info.dummy_val_buffer = it->second;
} else {
info.dummy_val_buffer =
Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
ret_var_->name_hint, 0, 0, kDefault);
dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer;
}
// The tcode is always a 32-bit int, so we don't need to have a separate
// map.
if (!dummy_tcode_buffer_.defined()) {
dummy_tcode_buffer_ =
Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
ret_tcode_->name_hint, 0, 0, kDefault);
}
info.dummy_tcode_buffer = dummy_tcode_buffer_;
return info;
}
Stmt WriteToOut(PrimExpr val) {
auto info = ConvertForFFI(val);
Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0});
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_val, store_tcode, ret_zero});
}
Var ret_var_;
Var ret_tcode_;
int in_parallel_{0};
std::unordered_map<int, Buffer> dummy_val_buffer_map_;
Buffer dummy_tcode_buffer_;
};
Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
ReturnRewriter rewriter(ret_var, ret_tcode);
return rewriter(body);
}
class SubroutineCallRewriter : public StmtExprMutator {
public:
static Optional<Stmt> Apply(const Map<GlobalVar, String> &packed_func_methods,
Stmt stmt) {
SubroutineCallRewriter rewriter(packed_func_methods);
stmt = rewriter.VisitStmt(std::move(stmt));
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
}
}
private:
explicit SubroutineCallRewriter(
const Map<GlobalVar, String> &packed_func_methods)
: packed_func_methods(packed_func_methods) {}
PrimExpr VisitExpr_(const CallNode *op) override {
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
for (auto arg : node->args) {
cpacked_args.push_back(arg);
}
// push an empty handle to be compatible with current cpacked convention
cpacked_args.push_back(tir::make_zero(DataType::Handle()));
made_change_ = true;
return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(),
cpacked_args);
}
}
return node;
}
const Map<GlobalVar, String> &packed_func_methods;
bool made_change_{false};
};
} // namespace
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}
inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
}
/* \brief Return the global_symbol of the function, if it should be updated
*
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or NullOpt if the function is to remain unchanged.
*/
Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return NullOpt;
}
}
// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
return NullOpt;
}
return global_symbol;
}
PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
return func;
}
std::string name_hint = global_symbol.value();
Target target = [&]() {
auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt) << "MakePackedAPI required the function to be annotated with "
"tvm::attr::kTarget ("
<< tvm::attr::kTarget
<< "), but the function only has attributes " << func->attrs;
return opt.value();
}();
int target_device_type = target->GetTargetDeviceType();
// A function without a host target has already been lowered.
Target target_host;
if (auto opt = target->GetHost()) {
target_host = opt.value();
} else {
return func;
}
auto *func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
// Data field definitions
// The packed fields
Var v_packed_args("args", DataType::Handle());
Buffer buf_packed_arg_type_ids =
decl_buffer({IntImm(DataType::Int(32), func_ptr->params.size())},
DataType::Int(32), "arg_type_ids");
Var v_num_packed_args("num_args", DataType::Int(32));
Var v_out_ret_value("out_ret_value", PointerType(PrimType(DataType::Void())));
Var v_out_ret_tcode("out_ret_tcode",
PointerType(PrimType(DataType::Int(32))));
Var v_resource_handle("resource_handle", DataType::Handle());
// The arguments of the function.
// The device context
Var device_id("dev_id");
Integer device_type(target_device_type);
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after init
std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
std::unordered_map<const VarNode *, PrimExpr> vmap;
ArgBinder binder(&vmap);
// ---------------------------
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{
v_packed_args, IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), builtin::kTVMValueContent)};
// load 64 bit version
DataType api_type = APIType(t);
PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
// cast to the target version.
if (api_type != t) {
res = Cast(t, res);
}
return res;
};
// Find the device API context argument based on name
for (const auto &param : func_ptr->params) {
if (param->name_hint == kDeviceContextVar) {
num_args--;
v_resource_handle = param;
break;
}
}
// Assert correct type codes for each argument. This must be done
// *before* any initialization steps produced by
// `binder.BindDLTensor()`. The validity of those initialization
// steps depends on the correct types being present, and must not
// occur before the type codes are actually checked.
seq_init.push_back(
MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string {
std::ostringstream error_message;
error_message << name_hint << ": num_args should be " << num_args;
return error_message.str();
}()));
seq_init.push_back(MakeAssertNotNull(
v_packed_args, name_hint + ": TVMValue* arg pointer was NULL"));
seq_init.push_back(MakeAssertNotNull(
buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL"));
seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));
// Need to delay binding of the buffers, in case some arguments also
// appear in the buffer.
std::vector<std::pair<PrimExpr, Var>> var_def;
std::vector<std::pair<Var, Buffer>> buffer_def;
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
// Ignore the device context argument, as it will still be passed
// as a native argument.
if (param->name_hint == kDeviceContextVar) {
continue;
}
var_def.emplace_back(f_arg_value(param.dtype(), i), param);
if (func_ptr->buffer_map.count(param)) {
buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
}
// type code checks
Var tcode(param->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(LetStmt(
tcode,
BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}),
nop));
DataType t = param.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_init.emplace_back(
AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(
AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_init.emplace_back(
AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
}
}
Array<Var> args{v_packed_args, buf_packed_arg_type_ids->data,
v_num_packed_args, v_out_ret_value,
v_out_ret_tcode, v_resource_handle};
// Arg definitions are defined before buffer binding to avoid the use before
// def errors.
//
// For example, for auto broadcasting, checks are required to guarantee that
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
for (const auto &[expr, param] : var_def) {
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}
for (const auto &kv : buffer_def) {
binder.BindDLTensor(kv.second, device_type, device_id, kv.first,
name_hint + "." + kv.first->name_hint);
arg_buffer_declarations.push_back(DeclBuffer(kv.second, nop));
}
func =
WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)},
{tvm::attr::kTarget, target_host}});
Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope,
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
ObjectRef node = String("default");
seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
bool need_set_device =
(target_device_type != kDLMicroDev &&
(
// or is c source target
target_device_type != kDLCPU || target->kind->name != "llvm"));
if (need_set_device) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device),
device_type, device_id}));
body = SeqStmt({set_device, body});
}
}
// Return error code of zero on success
body = SeqStmt({body, Evaluate(ret(Integer(0)))});
body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
arg_buffer_declarations},
body);
func_ptr->body = body;
func_ptr->params = args;
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
ICHECK_EQ(undefined.size(), 0)
<< "In PrimFunc " << name_hint << " variables " << undefined
<< " are used, but are not passed in as API arguments";
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function.
return func;
}
tvm::transform::Pass MakePackedAPI() {
using tvm::transform::Pass;
auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
Map<GlobalVar, String> packed_func_methods;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto prim_func = opt.value();
if (auto global_symbol = RequiresPackedAPI(prim_func)) {
packed_func_methods.Set(gvar, global_symbol.value());
}
}
}
IRModuleNode *mptr = mod.CopyOnWrite();
IRModule updates;
for (const auto &[gvar, base_func] : mptr->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
auto orig_func = func;
if (auto body = SubroutineCallRewriter::Apply(packed_func_methods,
func->body)) {
func.CopyOnWrite()->body = body.value();
}
func = MakePackedAPI(std::move(func));
if (!func.same_as(orig_func)) {
updates->Add(gvar, func);
}
}
}
if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MakePackedAPI").set_body_typed([]() {
return MakePackedAPI();
});
} // namespace tl
} // namespace tvm
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 0
@T.prim_func
def matmul(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
A_local = T.alloc_local((block_M, block_K), dtype)
B_local = T.alloc_local((block_K, block_N), dtype)
C_local = T.alloc_local((block_M, block_N), accum_dtype)
T.clear(C_local)
# Apply layout optimizations or define your own layout
# (Optional).
# T.annotate_layout(
# {
# A_local: make_swizzle_layout(A_local),
# B_local: make_swizzle_layout(B_local),
# }
# )
for ko in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, ko * block_K], A_local)
# Or Copy with Parallel
for k, j in T.Parallel(block_K, block_N):
B_local[k, j] = B[ko * block_K + k, by * block_N + j]
for i, j, k in T.grid(block_M, block_N, block_K):
C_local[i, j] += A_local[i, k] * B_local[k, j]
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul
def assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32):
func = matmul(M, N, K, block_M, block_N, block_K)
rt_mod, _ = tilelang.lower(func, target="c")
code = rt_mod.imported_modules[0].get_source()
assert code is not None, "Code generation failed"
def test_matmul_codegen():
assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -5,16 +5,26 @@
import tilelang as tl
import os
import os.path as osp
from typing import Union, Optional
from tilelang import tvm as tvm
from tvm import tir, relay
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.utils.target import determine_target
def is_device_call(func: tir.PrimFunc):
attrs = func.attrs
# consider c source as a device call
if "target" in attrs:
target = attrs["target"]
if target.kind.name == "c":
return True
return bool(func.attrs and "calling_conv" in func.attrs and
func.attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)
def is_host_call(func: tir.PrimFunc):
......@@ -95,13 +105,26 @@ def extrac_params(func: tir.PrimFunc):
return tensor_types
def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
def target_is_c(target):
if isinstance(target, str):
return target == "c"
return target.kind.name == "c"
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
return target_host
def lower(
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
target: Union[str, Target] = "auto",
target_host: Optional[Union[str, Target]] = None,
runtime_only=False,
):
# TODO(lei): Append C Source code host generation to the runtime
mod = func_or_mod
if isinstance(func_or_mod, tir.PrimFunc):
func = func_or_mod
......@@ -111,6 +134,8 @@ def lower(
if isinstance(target, str):
target = determine_target(target)
target_host = canon_target_host(target, target_host)
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
......@@ -167,13 +192,13 @@ def lower(
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tl.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
host_mod = tir.transform.Filter(is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod)
......@@ -187,6 +212,8 @@ def lower(
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
elif target_host.kind.name == "c":
host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
else:
raise ValueError("Target host is not supported")
......@@ -201,6 +228,10 @@ def lower(
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
elif target.kind.name == "c":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
elif target.kind.name == "llvm":
device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
else:
raise ValueError("Target is not supported")
......
......@@ -6,7 +6,7 @@ from typing import Union, List, Tuple, Optional
from collections import deque
from tvm import tir
from tvm.tir import Var
from tvm.script.ir_builder.tir.frame import TIRFrame, BlockFrame
from tvm._ffi import register_object
from tilelang import _ffi_api
......@@ -73,12 +73,21 @@ class KernelLaunchFrame(TIRFrame):
"""
super().__enter__()
_kernel_launch_frame_stack.push(self)
# If we have exactly 5 frames, return the single iter_var.var.
if len(self.frames) == 5:
return self.frames[0].iter_var.var
last_block_frame = self.frames[-1]
assert isinstance(last_block_frame, BlockFrame), "Last frame must be a block frame"
maybe_cpu = last_block_frame.annotations.get("tilelang.is_cpu_kernel_frame", False)
if maybe_cpu:
# CPU kernel frame, return a list of for frame items.
return [frame.vars[0] for frame in self.frames[0:-1]]
else:
# Otherwise, return a list of iter_var.var objects (excluding the last 4 frames).
# The last 4 frames are threadIdx.x, threadIdx.y, threadIdx.z, and the block frame carrying the attributes.
return [frame.iter_var.var for frame in self.frames[0:-4]]
def __exit__(self, ptype, value, trace):
......@@ -148,7 +157,8 @@ class KernelLaunchFrame(TIRFrame):
def Kernel(
*blocks: List[tir.PrimExpr],
threads: Optional[Union[int, List[int], Tuple]] = None,
is_cpu: bool = False,
prelude: Optional[str] = None,
):
"""Tools to quickly construct a GPU kernel launch frame.
......@@ -161,11 +171,13 @@ def Kernel(
An integer representing blockDim.x,
or a list of integers representing blockDim.(x|y|z).
If the value is -1, the threadIdx.x binding is skipped.
is_cpu : bool
Whether the kernel runs on CPU. If True, no threadIdx.(x|y|z)
or blockIdx.(x|y|z) bindings are created.
prelude : str
Additional C source for the kernel; it is injected before the
generated kernel code.
layout_annotation: Optional[Map[tir.Buffer, tir.IndexMap]]
The layout annotation map, used to annotate the layout of the buffers.
Returns
-------
......@@ -174,6 +186,9 @@ def Kernel(
"""
attrs: dict = {}
if not is_cpu and threads is None:
threads = 128 # default thread number
if isinstance(threads, int):
threads = [threads, 1, 1]
elif isinstance(threads, list):
......@@ -181,7 +196,10 @@ def Kernel(
elif isinstance(threads, tuple):
threads = list(threads) + [1] * (3 - len(threads))
else:
assert is_cpu, "threads must be an integer or a list of integers"
if is_cpu:
attrs["tilelang.is_cpu_kernel_frame"] = True
if prelude is not None:
attrs["pragma_import_c"] = prelude
......
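A minimal usage sketch contrasting the two frame kinds (assumes the frontend added in this commit; essentially a trimmed-down variant of the CPU matmul test above):

import tilelang
import tilelang.language as T

@T.prim_func
def cpu_fill(A: T.Buffer((8, 8), "float32")):
    # is_cpu=True: bx/by are serial grid-loop variables, no thread bindings.
    with T.Kernel(2, 2, is_cpu=True) as (bx, by):
        for i, j in T.Parallel(4, 4):
            A[bx * 4 + i, by * 4 + j] = 0.0

# GPU kernels keep the old behavior: threads defaults to 128 when omitted,
# and bx/by are bound to blockIdx.x / blockIdx.y.
rt_mod, _ = tilelang.lower(cpu_fill, target="c")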
......@@ -164,3 +164,25 @@ def LegalizeSafeMemoryAccess():
The result pass
"""
return _ffi_api.LegalizeSafeMemoryAccess() # type: ignore
def MakePackedAPI():
"""MakePackedAPI
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.MakePackedAPI() # type: ignore
def AnnotateDeviceRegions():
"""AnnotateDeviceRegions
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateDeviceRegions() # type: ignore
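A short sketch of how these wrappers slot into a lowering pipeline (mirrors the ordering used in tilelang/engine/lower.py above; `mod` is assumed to be a target-annotated IRModule):

import tilelang as tl

def apply_tl_host_passes(mod):
    # Same ordering as in the lower() pipeline above.
    mod = tl.transform.AnnotateDeviceRegions()(mod)
    # ... SplitHostDevice, ThreadSync, etc. run in between ...
    mod = tl.transform.MakePackedAPI()(mod)
    return mod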
......@@ -7,6 +7,14 @@ from tvm.target import Target
from tvm.contrib import rocm
from tilelang.contrib import nvcc
AVALIABLE_TARGETS = {
"auto",
"cuda",
"hip",
"c", # represent c source backend
"llvm",
}
def check_cuda_availability() -> bool:
"""
......@@ -64,8 +72,6 @@ def determine_target(target: Union[str, Target, Literal["auto"]]) -> Union[str,
raise ValueError("No CUDA or HIP available on this system.")
else:
# Validate the target if it's not "auto"
assert isinstance(
target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported"
return target
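A quick check of what the relaxed validation now accepts (sketch; assumes the import used in lower.py above):

from tilelang.utils.target import determine_target

# "c" (the C source backend) and "llvm" are now accepted in addition to
# "auto", "cuda", and "hip".
print(determine_target("c"))     # -> "c"
print(determine_target("llvm"))  # -> "llvm"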