"ts/webui/vscode:/vscode.git/clone" did not exist on "f84d90d699815ada9bc050261231a93a2f3ff147"
Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* use TVM_IMPORT_PYTHON_PATH to handle Python imports from a custom TVM build

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/runtime/runtime.h
* \brief Runtime functions.
*
*/
#ifndef TVM_TL_RUNTIME_RUNTIME_H_
#define TVM_TL_RUNTIME_RUNTIME_H_
namespace tvm {
namespace tl {
constexpr const char* tvm_tensormap_create_tiled = "__tvm_tensormap_create_tiled";
constexpr const char* tvm_tensormap_create_im2col = "__tvm_tensormap_create_im2col";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file target/codegen.cc
*/
#include "codegen_cuda.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <cmath>
#include <string>
#include <utility>
#include <vector>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "target/source/ptx.h"
namespace tvm {
namespace codegen {
CodeGenTileLangCUDA::CodeGenTileLangCUDA() { restrict_keyword_ = "__restrict__"; }
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
class LaunchConfigExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
LaunchConfigExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
return;
}
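// For illustration: with thread extents (threadIdx.x, threadIdx.y, threadIdx.z) = (128, 2, 1),
// the simplified product is 256, so this emits " __launch_bounds__(256)" after the return type.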
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}
std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
decl_stream << "#include <tl_templates/cuda/copy.h>\n";
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
decl_stream << "\n";
return CodeGenC::Finish();
}
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min);
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid
<< ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar& iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types";
os << "void*";
return;
}
if (t.is_void()) {
os << "void";
return;
}
if (t == tl::cuTensorMapType()) {
os << "CUtensorMap";
return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16:
if (t.is_scalar()) {
os << "half_t";
} else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements.
//
// half4 is stored as uint2
//
// h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
}
} else if (t.is_bfloat16()) {
if (t.is_scalar()) {
os << "bfloat16_t";
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
} else if (t.is_vector_bool()) {
// CUDA does not support bool vectors.
// Use ushort vectors to represent instead.
int n = t.lanes();
if (n <= 4) {
os << "ushort" << n;
return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << "u";
}
switch (t.bits()) {
case 1: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 8) {
os << "int8_t";
return;
} else if (t.lanes() == 16) {
os << "int16_t";
return;
} else if (t.lanes() == 32) {
os << "int";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 4) {
os << "int16_t";
return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
os << "int";
return;
} else if (t.lanes() == 16) {
os << "int2";
return;
} else if (t.lanes() == 32) {
os << "int4";
return;
} else if (t.lanes() == 64) {
os << "int8";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 8: {
if (t.lanes() == 4) {
// directly 4 8 bit int in integer.
// We use int for int8x4 instead of char4 because using char4 is
// likely to produce extra instructions to pack four int8 elements
// into 32-bit data.
os << "int";
return;
} else if (t.lanes() == 8) {
os << "int2";
return;
} else if (t.lanes() == 16) {
os << "int4";
return;
} else if (!t.is_uint() && t.is_scalar()) {
os << "signed char";
break;
} else {
os << "char";
break;
}
}
case 16: {
if (t.is_scalar()) {
os << "short";
} else if (t.lanes() <= 4) {
os << "short" << lanes;
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int16 vector elements.
//
// short4 is stored as int2
//
// s4.x is emitted as *(short2*)(&(i2.x)).x
// s4.y is emitted as *(short2*)(&(i2.x)).y
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for short type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
} else if (t.lanes() == 2) {
os << "longlong2";
} else if (t.lanes() == 3) {
os << "longlong3";
} else if (t.lanes() == 4) {
os << "longlong4";
}
return;
}
default:
fail = true;
break;
}
if (!fail && lanes == 1) {
return;
}
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
int ssa_scope = BeginScope();
{
// Unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
std::ostringstream value_temp;
if (isalpha(op[0])) {
value_temp << op << "(";
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << ", ";
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << op;
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore(sret, t, i, value_temp.str());
}
}
EndScope(ssa_scope);
os << sret;
}
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
}
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
}
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
}
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// Do nothing.
} else if (sync == "shared" || sync == "shared.dyn") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
}
}
void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
} else if (scope == "shared.dyn") {
os << "extern __shared__ __align__(1024) ";
}
}
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) {
os << "(";
if (target.is_uint()) {
os << "u";
}
os << "int)";
}
os << value << ")";
return os.str();
}
void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n";
{
std::string src = SSAGetID(PrintExpr(op->value), from_ty);
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
val << "(";
PrintType(target_ty.element_of(), val);
val << ")(";
PrintVecElemLoad(src, from_ty, i, val);
val << ")";
PrintVecElemStore(sret, target_ty, i, val.str());
}
}
os << sret;
}
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) {
//
// Emit an unsupported vector call
//
// v = intrin_f((float4*)A[0], (float4*)B[0])
//
// as
//
// float4 __ret;
// {
// float4 __arg0 = ((float4*)A)[0];
// float4 __arg1 = ((float4*)B)[0];
// __ret.x = intrin_f(__arg0.x, __arg1.x);
// __ret.y = intrin_f(__arg0.y, __arg1.y);
// __ret.z = intrin_f(__arg0.z, __arg1.z);
// __ret.w = intrin_f(__arg0.w, __arg1.w);
// }
// v = __ret;
//
// Declare the result vector.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(ret_dtype, stream);
stream << ' ' << sret << ";\n";
{
// Load arguments.
std::vector<std::string> sargs;
size_t arg_begin = static_cast<size_t>(skip_first_arg);
for (size_t i = arg_begin; i < args.size(); ++i) {
std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
sargs.push_back(std::move(val));
}
// Emit a scalar call for each lane.
for (int i = 0; i < ret_dtype.lanes(); ++i) {
std::ostringstream scall;
scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
}
scall << ")";
PrintVecElemStore(sret, ret_dtype, i, scall.str());
}
}
os << sret;
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
}
}
// Print a reference expression to a buffer.
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
const VarNode* buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
// bool is_vol = IsVolatile(buffer_var);
// always false for tl cutlass backend.
bool is_vol = false;
auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
std::ostringstream ptr_os;
ptr_os << "(";
if (is_vol) {
ptr_os << "volatile ";
}
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, ptr_os);
}
PrintType(pointed_to, ptr_os);
ptr_os << "*)";
return ptr_os.str();
};
DataType buffer_element_dtype = buffer->dtype;
std::string buffer_str = vid;
if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
std::stringstream temp;
temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
buffer_str = temp.str();
}
std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
// However, the backing type for scalar int4 and scalar bool is
// int32. Therefore, we need to divide by the ratio of their
// sizes in that case.
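// For illustration: a scalar int4 access at index 13 is emitted roughly as
//   *((int*)buf + 13 / 8)
// i.e. word 1 of the backing int32 array ("buf" stands in for the actual buffer vid).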
int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
os << "*("
<< "(" << ptr_cast(t) << vid << ")"
<< " + " << index_str << " / " << div_factor << ")";
} else if (t == buffer_element_dtype) {
os << buffer_str << "[" << index_str << "]";
} else {
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
}
return os.str();
}
void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent();
this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
};
if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src
<< "+" << src_offset << ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset
<< ", " << src << "+" << src_offset << ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
print_extern_call_stmt(func_name, 1);
} else if (op->op.same_as(builtin::create_barriers())) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]);
os << barrier_name + "[" + barrier_id + "]";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
print_extern_call_stmt("tl::mbarrier_arrive");
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
print_extern_call_stmt("tl::mbarrier_init");
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::MBarrierExpectTX())) {
print_extern_call_stmt("tl::mbarrier_expect_tx");
} else if (op->op.same_as(tl::MBarrierWaitParity())) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::TMALoadOp())) {
print_extern_call_stmt("tl::tma_load");
} else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::TMAStoreOp())) {
print_extern_call_stmt("tl::tma_store");
} else if (op->op.same_as(tl::LDMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async");
} else if (op->op.same_as(tl::SetMaxNReg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1])
<< ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
os << ")";
} else if (op->op.same_as(builtin::tvm_mma_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::tvm_bmma_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::bmma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::ptx_mma())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16, fp64, ...
// arg 4: B precision: fp16, fp64, ...
// arg 5: C precision: fp32, fp64, ...
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
// arg 13: (optional) 1-bit operator (xor or and)
ICHECK(op->args.size() == 13U || op->args.size() == 14U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code =
PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16, fp32, ...
// arg 4: B precision: fp16, fp32, ...
// arg 5: C precision: fp16, fp32, ...
// arg 6: A multiplicand pointer
// arg 7: A multiplicand index
// arg 8: B multiplicand pointer
// arg 9: B multiplicand index
// arg 10: C accumulator pointer
// arg 11: C accumulator index
// arg 12: metadata
// arg 13: metadata index
// arg 14: sparse_selector
// arg 15: saturate
ICHECK_EQ(op->args.size(), 16U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_offset = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
std::string metadata = this->PrintExpr(op->args[12]);
std::string metadata_offset = this->PrintExpr(op->args[13]);
std::string sparse_selector = this->PrintExpr(op->args[14]);
bool saturate = Downcast<Bool>(op->args[15])->value;
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
// arg 2: The data type in the matrix, .b16 is the only accepted data type.
// arg 3: pointer to local buffer.
// arg 4: The offset of the element to store in the local buffer.
// arg 5: pointer to the shared memory buffer to load.
// arg 6: The offset of the start element of the row to load in shared memory.
ICHECK_EQ(op->args.size(), 7U);
bool trans = Downcast<Bool>(op->args[0])->value;
int num = Downcast<Integer>(op->args[1])->value;
std::string type = Downcast<StringImm>(op->args[2])->value;
std::string local_ptr = this->PrintExpr(op->args[3]);
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
if (trans && op->dtype.bits() == 8) {
// Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an
// int8 matrix.
std::string smem_stride = this->PrintExpr(op->args[6]);
ICHECK(num == 4);
os << "for (int i = 0; i < 16; ++i) {\n";
os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
<< "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n";
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
}
} else if (op->op.same_as(builtin::mma_store())) {
int m = Downcast<Integer>(op->args[0])->value;
int n = Downcast<Integer>(op->args[1])->value;
std::string dst = this->PrintExpr(op->args[2]);
std::string src = this->PrintExpr(op->args[3]);
std::string src_offset = this->PrintExpr(op->args[4]);
PrimExpr stride = op->args[5];
ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now";
// Each thread in a warp holds a certain number of elements of an MMA output.
// For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements
// in its registers. So conceptually, a warp memory is organized as a 32x8 block.
// A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below.
// To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map
// to determine the output location for each 8 element.
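// Worked example of the 16x16 -> 32x8 map constructed below: element (i, j) = (3, 10)
// maps to thread_id = 4 * (3 % 8) + (10 % 8) / 2 = 13 and
// local_id = 4 * (10 / 8) + (3 / 8) * 2 + 10 % 2 = 4.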
const auto* index_map_func =
runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
IndexMap index_map;
if (!index_map_func) {
Var i, j;
// The index map is defined as follows:
index_map = IndexMap({i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                              4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
} else {
index_map = IndexMap::FromFunc(2, *index_map_func);
}
arith::Analyzer analyzer;
auto inverse_index_map =
index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
auto indices_16x16 = inverse_index_map->final_indices;
// "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine.
// FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them
// with the plain ones here.
class LowerFloorDivMod : public ExprMutator {
public:
PrimExpr VisitExpr_(const FloorDivNode* op) {
return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
}
PrimExpr VisitExpr_(const FloorModNode* op) {
return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
}
};
auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
if (op->dtype.bits() == 16) {
os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n";
os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])"
<< " = "
<< "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
os << "}\n";
} else {
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
os << dst << "[" + this->PrintExpr(dst_ind) + "]"
<< " = " << src << "[" << src_offset << " + local_id];\n";
os << "}\n";
}
} else if (op->op.same_as(builtin::mma_fill())) {
std::string num_elem = this->PrintExpr(op->args[0]);
std::string dst = this->PrintExpr(op->args[1]);
std::string dst_offset = this->PrintExpr(op->args[2]);
os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
os << dst << "[" << dst_offset << " + i] = 0.0;";
os << "}\n";
} else if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
need_cast_smem_ptr_to_int_ = true;
// use size of argument list to indicate whether or not to use predicated cp.async
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else {
this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size,
this->PrintExpr(op->args[5]));
}
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
need_cast_smem_ptr_to_int_ = true;
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
int barrier_id = Downcast<IntImm>(op->args[5])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier);
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string thread_count = this->PrintExpr(op->args[1]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string byte_count = this->PrintExpr(op->args[1]);
this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::create_barriers())) {
CHECK_EQ(barrier_count_, -1);
int barrier_count = Downcast<IntImm>(op->args[0])->value;
// pad barrier alignment to avoid runtime alignment errors
CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
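// e.g. with the 16-byte alignment used here there are 2 uint64_t slots per alignment unit,
// so a requested count of 3 barriers is padded up to 4 below.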
if (barrier_count % barrier_alignment_count != 0) {
barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count;
}
barrier_count_ = barrier_count;
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t "
<< barrier_name_ << "[" << barrier_count << "];\n";
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_
<< "[i] = 0; }\n";
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
"{.reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
// " @p ld.global.nc.f32 %0, [%1];}\n"t
" @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
: "=f"(reg)
: "l"(addr), "r"((int)guard)
);
*/
// get local
std::string reg = this->PrintExpr(op->args[0]);
// get guard
std::string guard = this->PrintExpr(op->args[1]);
const BufferLoadNode* addr_buffer = op->args[2].as<BufferLoadNode>();
std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
std::string local_addr = this->PrintExpr(op->args[3]);
this->stream << "asm volatile (\n";
this->stream << "\"{.reg .pred p;\\n\"\n";
this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
<< ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
<< guard << ")\n";
stream << ");\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
return;
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
this->VisitStmt(inner->body);
return;
} else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>();
ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
}
CodeGenC::VisitStmt_(op);
}
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>();
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
}
if (scope == "shared.dyn") {
stream << ' ' << vid << "[];\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
if (scope.find("wmma.") == 0) {
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
}
if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) &&
scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' ' << vid << '[' << constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", ";
}
os << "))";
}
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) {
// make_int8x4
const int64_t* p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
return;
}
if (op->dtype.is_float16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value);
os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
}
os << ')';
return;
}
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false;
const int64_t* p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xF;
if (lanes == 4) {
v = (v << 12) | (v << 8) | (v << 4) | v;
if (op->dtype.is_uint()) {
os << "(uint16_t)" << v;
} else {
os << "(int16_t)" << v;
}
} else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
if (lanes == 8) {
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
} else if (lanes == 16 || lanes == 32) {
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", ";
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
}
os << ')';
} else {
fail = true;
}
}
if (!fail) {
return;
}
}
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
case 64:
case 32: {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
} else if (std::isnan(op->value)) {
temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t,
const VarNode* variable, std::ostream& os) {
std::stringstream type;
PrintType(t, type);
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment "
<< variable->name_hint;
std::string shape_str = fragment_shapes.at(variable);
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string());
if (t.is_int()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::s4";
} else if (t.bits() == 1) {
type << "nvcuda::wmma::experimental::precision::b1";
} else {
LOG(FATAL) << "Unhandled integer type for wmma fragment!";
}
} else if (t.is_uint()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::u4";
} else {
LOG(FATAL) << "Unhandled integer type for wmma fragment!";
}
}
}
if (scope == "wmma.matrix_a") {
std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.matrix_b") {
std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.accumulator") {
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
<< ">";
}
}
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
int32_t size) {
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment "
<< variable->name_hint;
std::string shape_str = fragment_shapes.at(variable);
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
if (dim.first * dim.second != 0)
return size / dim.first / dim.second;
else
return 0;
}
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
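// For example, a volatile fp16 load printed as "buf[i]" is re-emitted here as
// "(half_t)(buf[i])", so the resulting temporary is no longer volatile-qualified
// ("buf" is a placeholder name).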
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
} else {
os << value;
}
}
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) {
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
}
if (t.is_float16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_half2(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}
if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}
if (i == 0) {
os << "make_";
PrintType(t, os);
os << "(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
return;
}
void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f, stream);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
// work around for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream);
stream << ' ' << vid;
continue;
}
}
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, stream);
}
} else {
CodeGenC::PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
} // namespace codegen
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file target/codegen.h
* \brief Utility to generate code
*/
#ifndef TVM_TL_TARGET_CODEGEN_CUDA_H_
#define TVM_TL_TARGET_CODEGEN_CUDA_H_
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
class CodeGenTileLangCUDA final : public CodeGenC {
public:
CodeGenTileLangCUDA();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p);
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// whether need mma.h
bool need_mma_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// The name of the barrier array in shared memory
const std::string barrier_name_ = "barrier";
// The alignment of the barrier array in shared memory
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
 * \file target/codegen_hip.cc
*/
#include "codegen_hip.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <cmath>
#include <string>
#include <utility>
#include <vector>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "target/source/ptx.h"
namespace tvm {
namespace codegen {
/*!
* \brief Replace patterns with replacement strings.
* \note should use std::format instead when codebase is ported to C++20.
*/
class Replacer {
public:
void register_rule(const std::string& pattern, const std::string& replacement) {
_rules.emplace_back(pattern, replacement);
}
std::string rewrite(std::string str) {
for (auto&& rule : _rules) {
auto [pattern, replacement] = rule;
size_t len = pattern.size();
size_t new_len = replacement.size();
size_t pos = str.find(pattern);
while (pos != std::string::npos) {
str = str.replace(pos, len, replacement);
pos = str.find(pattern, pos + new_len);
}
}
return str;
}
void empty_rules() { _rules.clear(); }
private:
std::vector<std::pair<std::string, std::string>> _rules;
};
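// Illustrative usage of Replacer (placeholder names here are arbitrary, not part of
// the emitted templates):
//   Replacer r;
//   r.register_rule("{dtype}", "float");
//   r.rewrite("*(({dtype}*)ptr)");   // yields "*((float*)ptr)"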
CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; }
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
class LaunchConfigExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
LaunchConfigExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
return;
}
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}
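// Illustrative effect: for a 16x16x1 thread block this appends " __launch_bounds__(256)"
// right after the return type, so the emitted signature looks roughly like
//   extern "C" __global__ void __launch_bounds__(256) main_kernel(...)
// (kernel name is hypothetical; if the thread extent is not a known constant, or is 1,
// nothing is added).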
std::string CodeGenTileLangHIP::Finish() {
  // HIP generated code always requires the hip_runtime header.
decl_stream << "#include <hip/hip_runtime.h>\n";
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
decl_stream << "#include <tl_templates/hip/gemm.h>\n";
decl_stream << "#include <tl_templates/hip/copy.h>\n";
decl_stream << "#include <tl_templates/hip/reduce.h>\n";
decl_stream << "#include <tl_templates/hip/ldsm.h>\n";
decl_stream << "#include <tl_templates/hip/threadblock_swizzle.h>\n";
decl_stream << "\n";
return CodeGenC::Finish();
}
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min);
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid
<< ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
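// Illustrative lowering: a TIR loop marked kUnrolled with min 0 and extent 8 becomes
//   #pragma unroll
//   for (int v = 0; v < 8; ++v) { ... }
// where the printed upper bound is Simplify(extent + min).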
void CodeGenTileLangHIP::BindThreadIndex(const IterVar& iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
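// e.g. an int32 loop variable bound to threadIdx.x is referenced in the generated
// source as "((int)threadIdx.x)"; no cast is emitted when the dtypes already match.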
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types";
os << "void*";
return;
}
if (t.is_void()) {
os << "void";
return;
}
if (t == tl::cuTensorMapType()) {
os << "CUtensorMap";
return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16:
if (t.is_scalar()) {
os << "half_t";
} else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements.
//
// half4 is stored as uint2
//
// h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
}
} else if (t.is_bfloat16()) {
if (t.is_scalar()) {
os << "bfloat16_t";
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
if (!fail) return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
} else if (t.is_vector_bool()) {
// CUDA does not support bool vectors.
// Use ushort vectors to represent instead.
int n = t.lanes();
if (n <= 4) {
os << "ushort" << n;
return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << "u";
}
switch (t.bits()) {
case 1: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 8) {
os << "int8_t";
return;
} else if (t.lanes() == 16) {
os << "int16_t";
return;
} else if (t.lanes() == 32) {
os << "int";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 4) {
os << "int16_t";
return;
} else if (t.lanes() == 8) {
// pack eight 4-bit ints directly into one 32-bit integer.
os << "int";
return;
} else if (t.lanes() == 16) {
os << "int2";
return;
} else if (t.lanes() == 32) {
os << "int4";
return;
} else if (t.lanes() == 64) {
os << "int8";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 8: {
if (t.lanes() == 4) {
// pack four 8-bit ints directly into one 32-bit integer.
// We use int for int8x4 instead of char4 because using char4 is
// likely to produce extra instructions to pack four int8 elements
// into 32-bit data.
os << "int";
return;
} else if (t.lanes() == 8) {
os << "int2";
return;
} else if (t.lanes() == 16) {
os << "int4";
return;
} else if (!t.is_uint() && t.is_scalar()) {
os << "signed char";
break;
} else {
os << "char";
break;
}
}
case 16: {
if (t.is_scalar()) {
os << "short";
} else if (t.lanes() <= 4) {
os << "short" << lanes;
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int16 vector elements.
//
// short4 is stored as int2
//
// s4.x is emitted as *(short2*)(&(i2.x)).x
// s4.y is emitted as *(short2*)(&(i2.x)).y
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lanes for short type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
} else if (t.lanes() == 2) {
os << "longlong2";
} else if (t.lanes() == 3) {
os << "longlong3";
} else if (t.lanes() == 4) {
os << "longlong4";
}
return;
}
default:
fail = true;
break;
}
if (!fail && lanes == 1) {
return;
}
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
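// Illustrative mappings produced above: float16 -> half_t, float16x8 -> uint4,
// float32x4 -> float4, int8x4 -> int, bool -> bool, boolx4 -> ushort4.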
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
int ssa_scope = BeginScope();
{
// Unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
std::ostringstream value_temp;
if (isalpha(op[0])) {
value_temp << op << "(";
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << ", ";
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
value_temp << op;
PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore(sret, t, i, value_temp.str());
}
}
EndScope(ssa_scope);
os << sret;
}
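// Illustrative output for a float4 add: a fresh temporary is declared and each lane is
// computed element-wise, e.g.
//   float4 _1;
//   _1.x = (a.x+b.x); _1.y = (a.y+b.y); _1.z = (a.z+b.z); _1.w = (a.w+b.w);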
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
}
void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
}
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
}
void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
}
}
void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
} else if (scope == "shared.dyn") {
os << "extern __shared__ __align__(1024) ";
}
}
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) {
os << "(";
if (target.is_uint()) {
os << "u";
}
os << "int)";
}
os << value << ")";
return os.str();
}
void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) {
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n";
{
std::string src = SSAGetID(PrintExpr(op->value), from_ty);
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
val << "(";
PrintType(target_ty.element_of(), val);
val << ")(";
PrintVecElemLoad(src, from_ty, i, val);
val << ")";
PrintVecElemStore(sret, target_ty, i, val.str());
}
}
os << sret;
}
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) {
//
// Emit an unsupported vector call
//
// v = intrin_f((float4*)A[0], (float4*)B[0])
//
// as
//
// float4 __ret;
// {
// float4 __arg0 = ((float4*)A)[0];
// float4 __arg1 = ((float4*)B)[0];
// __ret.x = intrin_f(__arg0.x, __arg1.x);
// __ret.y = intrin_f(__arg0.y, __arg1.y);
// __ret.z = intrin_f(__arg0.z, __arg1.z);
// __ret.w = intrin_f(__arg0.w, __arg1.w);
// }
// v = __ret;
//
// Declare the result vector.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
this->PrintType(ret_dtype, stream);
stream << ' ' << sret << ";\n";
{
// Load arguments.
std::vector<std::string> sargs;
size_t arg_begin = static_cast<size_t>(skip_first_arg);
for (size_t i = arg_begin; i < args.size(); ++i) {
std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
sargs.push_back(std::move(val));
}
// Emit a scalar call for each lane.
for (int i = 0; i < ret_dtype.lanes(); ++i) {
std::ostringstream scall;
scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
}
scall << ")";
PrintVecElemStore(sret, ret_dtype, i, scall.str());
}
}
os << sret;
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
}
}
// Print a reference expression to a buffer.
std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
const VarNode* buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
// bool is_vol = IsVolatile(buffer_var);
// always false for tl cutlass backend.
bool is_vol = false;
auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
std::ostringstream ptr_os;
ptr_os << "(";
if (is_vol) {
ptr_os << "volatile ";
}
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, ptr_os);
}
PrintType(pointed_to, ptr_os);
ptr_os << "*)";
return ptr_os.str();
};
DataType buffer_element_dtype = buffer->dtype;
std::string buffer_str = vid;
if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
std::stringstream temp;
temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
buffer_str = temp.str();
}
std::string index_str = PrintExpr(index);
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodeGenTileLangHIP::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
// However, the backing type for scalar int4 and scalar bool is
// int32. Therefore, we need to divide by the ratio of their
// sizes in that case.
int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
os << "*("
<< "(" << ptr_cast(t) << vid << ")"
<< " + " << index_str << " / " << div_factor << ")";
} else if (t == buffer_element_dtype) {
os << buffer_str << "[" << index_str << "]";
} else {
os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
}
return os.str();
}
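// Illustrative references: for a float16 buffer A, a scalar access prints "A[i]", while a
// float16x8 vector access reinterprets the base pointer and prints "*(uint4*)(A + i)"
// (exact casts depend on the registered handle type).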
void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent();
this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
};
if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src
<< "+" << src_offset << ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset
<< ", " << src << "+" << src_offset << ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
print_extern_call_stmt(func_name, 1);
} else if (op->op.same_as(builtin::create_barriers())) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]);
os << barrier_name + "[" + barrier_id + "]";
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
print_extern_call_stmt("tl::mbarrier_arrive");
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
print_extern_call_stmt("tl::mbarrier_init");
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::MBarrierExpectTX())) {
print_extern_call_stmt("tl::mbarrier_expect_tx");
} else if (op->op.same_as(tl::MBarrierWaitParity())) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::TMALoadOp())) {
print_extern_call_stmt("tl::tma_load");
} else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
print_extern_call_stmt("tl::tma_load_im2col");
} else if (op->op.same_as(tl::TMAStoreOp())) {
print_extern_call_stmt("tl::tma_store");
} else if (op->op.same_as(tl::LDMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async");
} else if (op->op.same_as(tl::SetMaxNReg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1])
<< ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::load_matrix_sync(";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[6], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::store_matrix_sync(";
this->PrintExpr(op->args[5], os);
os << ", ";
this->PrintExpr(op->args[0], os);
os << "[";
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
}
os << ")";
} else if (op->op.same_as(builtin::tvm_mma_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::mma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::tvm_bmma_sync())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::bmma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::tvm_mfma())) {
// arg 0: prefix: {otype}_16x16x16{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: float16, float32, ...
// arg 4: B precision: float16, float32, ...
// arg 5: C precision: float32, float64, ...
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma";
std::string prefix = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
ICHECK(A_layout == "row" || B_layout == "row") << "Matrix core only support row major";
// map for dtype -> float32x4 -> float4
std::unordered_map<std::string, std::string> dtype_map = {
{"int8", "char"},
{"int32", "int"},
{"int8x4", "int32_t"},
{"int32x4", "int32x4"},
{"float16", "half"},
{"float32", "float"},
{"float64", "double"},
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"},
{"float32x16", "float32x16"}
};
std::string call_mfma_code = R"({
*((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
*((({B_dytpe}*){b_ref}) + {b_bias}),
*((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;
replacer.register_rule("{mfma_buildin}", mfma_buildin);
replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
replacer.register_rule("{a_ref}", a_ref);
replacer.register_rule("{a_bias}", a_bias);
replacer.register_rule("{b_ref}", b_ref);
replacer.register_rule("{b_bias}", b_bias);
replacer.register_rule("{c_ref}", c_ref);
replacer.register_rule("{c_bias}", c_bias);
os << replacer.rewrite(call_mfma_code);
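// Illustrative expansion for prefix "f32_16x16x16f16" (assuming the tl HIP headers
// provide the float16x4/float32x4 vector typedefs):
//   *(((float32x4*)C) + c_bias) = __builtin_amdgcn_mfma_f32_16x16x16f16(
//       *(((float16x4*)A) + a_bias), *(((float16x4*)B) + b_bias),
//       *(((float32x4*)C) + c_bias), 0, 0, 0);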
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For HIP, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
return;
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For HIP, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
this->VisitStmt(inner->body);
return;
} else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>();
ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
}
CodeGenC::VisitStmt_(op);
}
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
if (scope == "shared.dyn") {
stream << ' ' << vid << "[];\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) &&
scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' ' << vid << '[' << constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
}
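// e.g. a shared allocation of 128 int4 elements is emitted as "__shared__ int v[16];"
// since sub-byte elements are packed into 32-bit words.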
void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", ";
}
os << "))";
}
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) {
// make_int8x4
const int64_t* p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
return;
}
if (op->dtype.is_float16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value);
os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
}
os << ')';
return;
}
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false;
const int64_t* p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xF;
if (lanes == 4) {
v = (v << 12) | (v << 8) | (v << 4) | v;
if (op->dtype.is_uint()) {
os << "(uint16_t)" << v;
} else {
os << "(int16_t)" << v;
}
} else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
if (lanes == 8) {
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
} else if (lanes == 16 || lanes == 32) {
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", ";
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
}
os << ')';
} else {
fail = true;
}
}
if (!fail) {
return;
}
}
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
}
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
case 64:
case 32: {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
} else if (std::isnan(op->value)) {
temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
}
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
} else {
os << value;
}
}
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) {
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
}
if (t.is_float16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_half2(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}
if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}
if (i == 0) {
os << "make_";
PrintType(t, os);
os << "(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
return;
}
void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix(stream);
CodeGenC::PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f, stream);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
// work around for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream);
stream << ' ' << vid;
continue;
}
}
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
if (no_alias) {
PrintRestrict(v, stream);
}
} else {
CodeGenC::PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
} // namespace codegen
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
 * \file target/codegen_hip.h
 * \brief HIP code generator for TileLang.
*/
#ifndef TVM_TL_TARGET_CODEGEN_HIP_H_
#define TVM_TL_TARGET_CODEGEN_HIP_H_
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
class CodeGenTileLangHIP final : public CodeGenC {
public:
CodeGenTileLangHIP();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p);
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether wmma intrinsics are needed
bool need_wmma_h_{false};
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// whether need mma.h
bool need_mma_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// The name of the barrier array in shared memory
const std::string barrier_name_ = "barrier";
// The alignment of the barrier array in shared memory
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_HIP_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "runtime/cuda/cuda_module.h"
#include "codegen_cuda.h"
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "cubin";
} else {
ICHECK(0);
}
return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
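// The builder is exposed through the global registry below; from Python it can be
// retrieved as tvm.get_global_func("target.build.tilelang_cuda") (illustrative).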
String BuildTLDebug(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangCUDA cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
return String(code);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda").set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen").set_body_typed(BuildTLDebug);
} // namespace codegen
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#if defined(__linux__)
#include <sys/stat.h>
#endif
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include "runtime/rocm/rocm_module.h"
#include "codegen_hip.h"
namespace tvm {
namespace codegen {
#define HIPRTC_CALL(x)                                                                         \
  {                                                                                            \
    hiprtcResult result = x;                                                                   \
    if (result != HIPRTC_SUCCESS) {                                                            \
      LOG(FATAL) << "HiprtcError: " #x " failed with error: " << hiprtcGetErrorString(result); \
    }                                                                                          \
  }
static std::string FindHIPIncludePath() {
#if defined(_WIN32)
const std::string delimiter = "\\";
#else
const std::string delimiter = "/";
#endif
std::string hip_include_path;
const char* hip_path_env = std::getenv("HIP_PATH");
if (hip_path_env != nullptr) {
hip_include_path += hip_path_env;
hip_include_path += delimiter + "include";
return hip_include_path;
}
#if defined(__linux__)
struct stat st;
hip_include_path = "/opt/rocm/hip/include";
if (stat(hip_include_path.c_str(), &st) == 0) {
return hip_include_path;
}
if (stat("/usr/include/hip/hip_runtime.h", &st) == 0) {
return "/usr/include/hip";
}
#endif
LOG(FATAL) << "Cannot find HIP include path."
<< "HIP_PATH is not set or ROCm is not installed in the default installation path."
<< "In other than linux, it is necessary to set HIP_PATH.";
return hip_include_path;
}
static std::string HIPRTCCompile(const std::string& code, bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
hiprtcProgram prog;
std::string cc = "gfx900"; // Default target architecture (can be changed as needed)
int major, minor;
hipError_t e1 = hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, 0);
hipError_t e2 = hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor, 0);
if (e1 == hipSuccess && e2 == hipSuccess) {
cc = "gfx" + std::to_string(major * 100 + minor * 10);
} else {
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to gfx900.";
}
compile_params.push_back("--gpu-architecture=" + cc);
if (include_path) {
std::string include_option = "--include-path=" + FindHIPIncludePath();
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
param_cstrings.push_back(string.c_str());
}
HIPRTC_CALL(hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
hiprtcResult compile_res =
hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
size_t log_size;
HIPRTC_CALL(hiprtcGetProgramLogSize(prog, &log_size));
std::string log;
log.resize(log_size);
HIPRTC_CALL(hiprtcGetProgramLog(prog, &log[0]));
ICHECK_EQ(compile_res, HIPRTC_SUCCESS) << log;
size_t code_size;
HIPRTC_CALL(hiprtcGetCodeSize(prog, &code_size));
std::string code_out;
code_out.resize(code_size);
HIPRTC_CALL(hiprtcGetCode(prog, &code_out[0]));
HIPRTC_CALL(hiprtcDestroyProgram(&prog));
return code_out;
}
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
if (f->params[i]->dtype.is_handle()) {
auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
if (ptr && ptr->storage_scope == "grid_constant") {
info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
continue;
}
}
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol.value())] = info;
}
return fmap;
}
runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenTileLangHIP cg;
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangHIP: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_hip_postproc")) {
code = (*f)(code, target).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_hip_compile")) {
ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "hsaco";
} else {
ptx = HIPRTCCompile(code, false);
}
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}
TVM_REGISTER_GLOBAL("target.build.tilelang_hip").set_body_typed(BuildTileLangHIP);
} // namespace codegen
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/target/utils.cc
* \brief helper functions for target attributes.
*/
#include "utils.h"
namespace tvm {
namespace tl {
bool TargetIsCuda(Target target) { return target->GetTargetDeviceType() == kDLCUDA; }
bool TargetIsRocm(Target target) { return target->GetTargetDeviceType() == kDLROCM; }
int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char* arch_str = s.value().c_str();
ICHECK_EQ(arch_str[0], 's');
ICHECK_EQ(arch_str[1], 'm');
ICHECK_EQ(arch_str[2], '_');
return atoi(&arch_str[3]);
}
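// e.g. arch "sm_80" yields 80 and "sm_90a" yields 90, since atoi stops at the first
// non-digit character.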
bool TargetIsVolta(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 70 && arch < 75;
}
bool TargetIsTuring(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 75 && arch < 80;
}
bool TargetIsAmpere(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 80 && arch < 90;
}
bool TargetIsHopper(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 90;
}
bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target)) return false;
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
// if mcpu start with "gfx9", it is CDNA
return mcpu.find("gfx9") == 0;
}
return false;
}
bool TargetHasAsyncCopy(Target target) {
if (TargetIsCuda(target)) {
int arch = GetArchInt(target);
return arch >= 80;
} else if (TargetIsCDNA(target)) {
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
if (mcpu.rfind("gfx9", 0) == 0) {
int gfx_version = std::stoi(mcpu.substr(3, 2));
return gfx_version >= 94;
}
return false;
} else {
return false;
}
}
return false;
}
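// e.g. mcpu "gfx942" -> gfx_version 94 -> true, while "gfx90a" -> 90 -> false;
// on CUDA, sm_80 and newer report true.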
bool TargetHasLdmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 75;
}
bool TargetHasStmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
int arch = GetArchInt(target);
return arch >= 90;
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/target/utils.h
* \brief helper functions for target attributes.
*
*/
#ifndef TVM_TL_TARGET_UTILS_H_
#define TVM_TL_TARGET_UTILS_H_
#include <tvm/target/target.h>
namespace tvm {
namespace tl {
bool TargetIsCuda(Target target);
bool TargetIsRocm(Target target);
bool TargetIsVolta(Target target);
bool TargetIsTuring(Target target);
bool TargetIsAmpere(Target target);
bool TargetIsHopper(Target target);
bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TARGET_UTILS_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cuda_runtime.h>
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>
using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define htanh cutlass::fast_tanh
#define hpow powf
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
// Pack two half values.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
/// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const* const ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t* address, half_t val) {
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd(reinterpret_cast<half*>(address), static_cast<half>(val));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t* address, half_t* val) {
atomicAdd(reinterpret_cast<half*>(address), static_cast<half>(*val));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAddx2(half_t* address, half_t* val) {
atomicAdd(reinterpret_cast<half2*>(address), static_cast<half2>(*reinterpret_cast<half2*>(val)));
}
TL_DEVICE void atomicAdd(half_t* address, float val) {
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd(reinterpret_cast<half*>(address), __float2half(val));
}
// DP4A
template<typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype* a, InDatatype* b, OutDatatype* c) {
const int a_int = *((int*)a);
const int b_int = *((int*)b);
const int c_int = *((int*)c);
*c = __dp4a(a_int, b_int, c_int);
}
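// Illustrative device-side usage (values are hypothetical):
//   int8_t a[4], b[4]; int c = 0;
//   DP4A(a, b, &c);   // c += a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]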
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "copy_sm90.h"
#endif
namespace tl {
TL_DEVICE void cp_async_commit() { asm volatile("cp.async.commit_group;\n" ::); }
template <int N>
TL_DEVICE void cp_async_wait() {
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
}
template <int N>
TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N));
} else {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.ca.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N));
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global_ptr, bool cond) {
static_assert(N == 16 || N == 8 || N == 4);
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes));
} else {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes));
}
}
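// Illustrative usage from generated kernels:
//   tl::cp_async_gs<16>(smem_ptr, gmem_ptr);  // issue a 16-byte global->shared copy
//   tl::cp_async_commit();                    // close the current async group
//   tl::cp_async_wait<0>();                   // wait for all outstanding groups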
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cuda.h>
#include "common.h"
namespace tl {
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2, int32_t const& crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2),
"r"(crd3)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2),
"r"(crd3), "r"(crd4)
: "memory");
}
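// im2col variant: loads a 4D activation tile at the given (c, w, h, n) base
// coordinates and applies the 16-bit (w, h) offsets supported by an im2col TMA
// descriptor.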
TL_DEVICE void tma_load_im2col(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& coord_c,
int32_t const& coord_w, int32_t const& coord_h,
int32_t const& coord_n, uint16_t const& offset_w,
uint16_t const& offset_h) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(coord_c), "r"(coord_w),
"r"(coord_h), "r"(coord_n), "h"(offset_w), "h"(offset_h)
: "memory");
}
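// Bulk tensor stores from shared::cta to global memory; completion is tracked by
// the enclosing cp.async.bulk commit/wait group rather than an mbarrier.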
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2,
int32_t const& crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2,
int32_t const& crd3, int32_t const& crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
}
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap& descriptor) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}
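// mbarrier helpers used to pair with the TMA loads above. A typical producer /
// consumer exchange (illustrative sketch, hypothetical names) looks like:
//
//   __shared__ uint64_t bar;
//   if (threadIdx.x == 0) mbarrier_init(bar, /*arrive_count=*/1);
//   __syncthreads();
//   if (threadIdx.x == 0) {
//     mbarrier_arrive_expect_tx(bar, kTileBytes);   // announce the bytes TMA will deliver
//     tma_load(desc, bar, smem_tile, crd0, crd1);   // async bulk copy completes on `bar`
//   }
//   mbarrier_wait(bar, phase);                      // spin until the expected bytes arrive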
TL_DEVICE void mbarrier_init(uint64_t& smem_barrier, uint32_t arrive_count) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;" : : "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_wait(uint64_t& smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}\n" ::"r"(smem_int_ptr),
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t& smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_expect_tx(uint64_t& smem_barrier, uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t& smem_barrier, uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t& smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void fence_proxy_async() { asm volatile("fence.proxy.async.shared::cta;" : :); }
TL_DEVICE void syncthreads_partial(uint64_t& smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state;
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "mbarrier.arrive.shared.b64 %0, [%1];\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.shared.b64 P1, [%1], %0;\n"
      "@!P1 bra.uni LAB_WAIT;\n"
      "}\n"
      : "=l"(state)
      : "r"(smem_int_ptr));
}
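// Hopper register reallocation between producer/consumer warpgroups via
// setmaxnreg; per the PTX spec, RegCount is expected to be a multiple of 8 in
// the range [24, 256].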
template <uint32_t RegCount>
TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
template <uint32_t RegCount>
TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "gemm_sm90.h"
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
#include "gemm_sm80.h"
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700))
#include "gemm_sm70.h"
#else
#endif
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/mma_tensor_op_sm70.h>
#include "common.h"
using cutlass::gemm::GemmShape;
// Primary template
// Add 128 bits padding when the last dim is a multiple of 256 bits
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
using Layout = typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
static int constexpr Dim = transpose ? M : K;
static int constexpr Stride = (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
using Layout = typename std::conditional<transpose, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type;
static int constexpr Dim = transpose ? K : N;
static int constexpr Stride = (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K, typename std::enable_if<M % 64 == 0>::type> {
using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
static int constexpr Stride = M;
};
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
static int constexpr Stride = M;
};
template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
using Layout = cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
static int constexpr Stride = N;
};
template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
typename std::enable_if<N % 64 == 0>::type> {
using Layout = cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
static int constexpr Stride = N;
};
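// Warp-level GEMM for Volta (SM70) built on cutlass::gemm::warp::MmaVoltaTensorOp.
// The Dispatch* traits above pick a shared-memory layout (padded by default,
// Volta tensor-op multiplicand layouts for half_t) together with its leading stride.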
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = A_type_raw;
using B_type = B_type_raw;
using C_type = C_type_raw;
using InstructionShape = GemmShape<16, 16, 4>;
using SMemLayoutA =
typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM, Shape::kK>::Layout;
using SMemLayoutB =
typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN, Shape::kK>::Layout;
static constexpr int stride_A =
DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM, Shape::kK>::Stride;
static constexpr int stride_B =
DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN, Shape::kK>::Stride;
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape, 32, A_type,
typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
B_type,
typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>::type,
C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1> >;
static_assert(Shape::kM % num_warp_m == 0);
static_assert(Shape::kN % num_warp_n == 0);
using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n, InstructionShape::kK>, A_type,
SMemLayoutA, B_type, SMemLayoutB, C_type, cutlass::layout::RowMajor, Policy>;
using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
using FragmentA = typename MmaWarp::FragmentA;
using FragmentB = typename MmaWarp::FragmentB;
using FragmentC = typename MmaWarp::FragmentC;
using IteratorA = typename MmaWarp::IteratorA;
using IteratorB = typename MmaWarp::IteratorB;
static_assert(Shape::kK % InstructionShape::kK == 0);
static int constexpr kKgroups = Shape::kK / InstructionShape::kK;
static CUTLASS_DEVICE void body(A_type_raw* pA, B_type_raw* pB, FragmentC& accum,
const int warp_idx_m, const int warp_idx_n, const int lane_id) {
MmaWarp mma_op;
FragmentA frag_A;
FragmentB frag_B;
const TensorRefA ref_A((A_type*)pA, stride_A);
const TensorRefB ref_B((B_type*)pB, stride_B);
IteratorA iter_A(ref_A, lane_id);
IteratorB iter_B(ref_B, lane_id);
iter_A.add_tile_offset({warp_idx_m, 0});
iter_B.add_tile_offset({0, warp_idx_n});
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_A.load(frag_A);
iter_B.load(frag_B);
++iter_A;
++iter_B;
mma_op(accum, frag_A, frag_B, accum);
}
}
static CUTLASS_DEVICE void body_rs(const FragmentA* frag_A, B_type_raw* pB, FragmentC& accum,
const int warp_idx_n, const int lane_id) {
MmaWarp mma_op;
FragmentB frag_B;
const TensorRefB ref_B((B_type*)pB, stride_B);
IteratorB iter_B(ref_B, lane_id);
iter_B.add_tile_offset({0, warp_idx_n});
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_B.load(frag_B);
++iter_B;
mma_op(accum, frag_A[k], frag_B, accum);
}
}
};
namespace tl {
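// Block-level entry points: gemm_ss reads both A and B from shared memory, while
// gemm_rs takes A as register fragments produced by an earlier stage.
// Illustrative instantiation (the tile shapes here are hypothetical):
//
//   tl::gemm_ss<64, 64, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
//               /*trans_A=*/false, /*trans_B=*/false,
//               half_t, half_t, float>(A_shared, B_shared, C_accum);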
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A, trans_B, A_type,
B_type, C_type>;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pB, *(FragmentC*)(accum), warp_id / num_warp_n, warp_id % num_warp_n, lane_id);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A, trans_B, A_type,
B_type, C_type>;
using FragmentA = typename MMA::FragmentA;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body_rs((const FragmentA*)(pA), pB, *(FragmentC*)(accum), warp_id % num_warp_n, lane_id);
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cute/algorithm/copy.hpp>
#include "common.h"
namespace cute {
template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <>
struct DispatchInstruction<half_t, half_t, half_t> {
using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<half_t, half_t, float> {
using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<int8_t, int8_t, int> {
using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _1>>;
};
template <>
struct DispatchInstruction<double, double, double> {
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
using MMA_Group = Layout<Shape<_2, _2, _1>>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <>
struct DispatchInstruction<half_t, half_t, float> {
using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using MMA_Group = Layout<Shape<_1, _2, _2>>;
};
#endif
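// DispatchInstruction maps the (A, B, C) element types to an MMA atom for the
// current architecture. OperandTraits below selects a shared-memory layout
// (swizzled where an ldmatrix-friendly tiling applies, padded otherwise) and a
// matching copy atom, keyed on the element width in bits and the inner dimension.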
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
// Primary template, use padded layout and default copy
static constexpr int stride = K_inner ? K : N;
static constexpr int padded = stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
using Layout =
typename std::conditional<K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<16, N, K, true, typename std::enable_if<K % 64 == 32>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, true, typename std::enable_if<K % 64 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<16, N, K, false, typename std::enable_if<N % 64 == 32>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<16, N, K, false, typename std::enable_if<N % 64 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
using Copy = SM75_U16x8_LDSM_T;
};
template <int N, int K>
struct OperandTraits<32, N, K, true, typename std::enable_if<K % 32 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, true, typename std::enable_if<K % 32 == 16>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<32, N, K, false, typename std::enable_if<N % 32 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<32, N, K, false, typename std::enable_if<N % 32 == 16>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
using Copy = UniversalCopy<tfloat32_t>;
};
template <int N, int K>
struct OperandTraits<8, N, K, true, typename std::enable_if<K % 128 == 64>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<8, N, K, true, typename std::enable_if<K % 128 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = SM75_U32x4_LDSM_N;
};
template <int N, int K>
struct OperandTraits<64, N, K, true, typename std::enable_if<K % 16 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
using Copy = DefaultCopy;
};
template <int N, int K>
struct OperandTraits<64, N, K, false, typename std::enable_if<N % 16 == 0>::type> {
using LayoutAtom =
decltype(composition(Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}, Step<_2, _1>{}));
using Copy = DefaultCopy;
};
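// CuTe-based block GEMM for SM75/SM80. body() streams both A and B tiles from
// shared memory through the copy atoms above; body_rs/body_sr keep one operand
// in registers and software-pipeline the shared-memory loads of the other by
// one k-step.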
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = typename std::conditional<std::is_same<A_type_raw, float>::value, tfloat32_t,
A_type_raw>::type;
using B_type = typename std::conditional<std::is_same<B_type_raw, float>::value, tfloat32_t,
                                           B_type_raw>::type;
using C_type = C_type_raw;
using Instruction = DispatchInstruction<A_type, B_type, C_type>;
using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
using OperandBTraits = OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;
using TileMma =
TiledMMA<typename Instruction::MMA, Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
typename Instruction::MMA_Group>;
template <class... Args>
static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const& layout) {
return layout;
}
  // In fp16, when the layout is KxN, num_warp_n is 1, and N % 64 == 0,
  // the original layout fails to compile; this is currently used as a workaround.
template <class... Args>
static CUTE_DEVICE auto remove_swizzle(ComposedLayout<Args...> const& layout) {
if constexpr (sizeof(A_type) == 2)
return layout.layout_b();
else
return layout;
}
static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCsA = thr_copy_A.partition_S(sA);
Tensor tCsB = thr_copy_B.partition_S(sB);
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    // When the layout is KxN and num_warp_n is 1, there seems to be a bug; use this as a workaround.
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCsB = thr_copy_B.partition_S(sB);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast<A_type*>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
static CUTE_DEVICE void body_sr(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCsA = thr_copy_A.partition_S(sA);
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
Tensor tCrB = make_tensor(make_rmem_ptr(reinterpret_cast<B_type*>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
};
} // namespace cute
namespace tl {
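// gemm_ss: A and B in shared memory; gemm_rs: A in registers, B in shared memory;
// gemm_sr: A in shared memory, B in registers.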
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
MMA::body(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
MMA::body_rs(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/arch/barrier.h>
#include <cute/algorithm/copy.hpp>
#include "common.h"
namespace cute {
template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
auto BLK_MN0 = size<0>(BLK_MN{});
auto BLK_K0 = size<0>(BLK_K{});
static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8.");
static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8.");
if constexpr (major == GMMA::Major::MN) {
if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom<ElementType>{}) == 0) {
return GMMA::Layout_MN_SW128_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom<ElementType>{}) == 0) {
return GMMA::Layout_MN_SW64_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom<ElementType>{}) == 0) {
return GMMA::Layout_MN_SW32_Atom<ElementType>{};
} else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0) {
return GMMA::Layout_MN_INTER_Atom<ElementType>{};
} else {
static_assert(
BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{}) == 0,
"BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom<ElementType>{})");
}
} else if constexpr (major == GMMA::Major::K) {
if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom<ElementType>{}) == 0) {
return GMMA::Layout_K_SW128_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom<ElementType>{}) == 0) {
return GMMA::Layout_K_SW64_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom<ElementType>{}) == 0) {
return GMMA::Layout_K_SW32_Atom<ElementType>{};
} else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0) {
return GMMA::Layout_K_INTER_Atom<ElementType>{};
} else {
static_assert(
BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{}) == 0,
"BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom<ElementType>{})");
}
}
}
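// Hopper warpgroup GEMM. ss_smem_selector above picks the largest GMMA
// shared-memory swizzle atom that divides the tile; body()/body_rs() issue wgmma
// instructions through GMMA::ss_op_selector / rs_op_selector and, when
// wg_wait >= 0, block on warpgroup_wait<wg_wait>() before returning.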
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B,
typename A_type_raw, typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value, tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value, tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr GMMA::Major GmmaMajorA = trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB = trans_B ? GMMA::Major::K : GMMA::Major::MN;
using SmemLayoutAtomA = decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
using SmemLayoutAtomB = decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
// static_assert(num_warp_n == 1);
static_assert(num_warp_m % 4 == 0);
template <int wg_wait=0>
static CUTE_DEVICE void body(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
const int tid = threadIdx.x;
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type*>(pA)), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
auto tiled_mma =
make_tiled_mma(GMMA::ss_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrA = thr_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(acc);
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(acc);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
}
template <int wg_wait=0>
static CUTE_DEVICE void body_rs(A_type_raw* pA, B_type_raw* pB, C_type_raw* pC) {
// TODO: Move bar.sync out of body_rs
// asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n * 32));
const int tid = threadIdx.x;
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type*>(pB)), SmemLayoutB{});
auto tiled_mma =
make_tiled_mma(GMMA::rs_op_selector<A_type, B_type, C_type, Shape<Int<M>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
auto thr_mma = tiled_mma.get_thread_slice(tid);
// Allocate registers for pipelining
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)
Tensor tCrA = make_tensor(make_rmem_ptr(reinterpret_cast<A_type*>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
Tensor acc = make_tensor(make_rmem_ptr(reinterpret_cast<C_type*>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
warpgroup_arrive();
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(acc);
warpgroup_fence_operand(tCrA);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
}
};
} // namespace cute
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
  MMA::template body<wg_wait>(pA, pB, accum);
}
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, bool trans_B, int wg_wait=0,
typename A_type, typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type* pA, B_type* pB, C_type* accum) {
using MMA =
cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, A_type, B_type, C_type>;
  MMA::template body_rs<wg_wait>(pA, pB, accum);
}
template <int num_mma>
TL_DEVICE void wait_wgmma() {
warpgroup_wait<num_mma>();
}
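// Named-barrier based warpgroup scheduling: with 256 MMA threads (2 warpgroups)
// or 384 MMA threads (3 warpgroups), each warpgroup syncs on its own barrier id
// and, on arrive, releases the other warpgroup(s) in round-robin order; mma_init
// pre-arrives those barriers so that warpgroup 0 runs its MMAs first.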
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_sync() {
cutlass::arch::NamedBarrier::sync(
NumMmaThreads,
cutlass::canonical_warp_group_idx() /*id*/);
}
template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_arrive() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads,
                                        (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 1
                            ? cutlass::canonical_warp_group_idx() + 1
                            : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (cutlass::canonical_warp_group_idx() <= 0
                            ? cutlass::canonical_warp_group_idx() + 2
                            : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}
template <int NumMmaThreads>
TL_DEVICE void mma_init() {
static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
if (cutlass::canonical_warp_group_idx() > 0) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0);
}
if constexpr (NumMmaThreads == 384) {
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
}
}
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
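// Thin wrappers around ldmatrix/stmatrix: each call moves one, two, or four 8x8
// b16 matrices between shared memory and per-thread registers; the .trans
// variants transpose each 8x8 tile. The stmatrix wrappers target sm_90 and newer.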
TL_DEVICE void ptx_ldmatrix_x1(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x2(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x4(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x1_trans(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
: "=r"(value[0])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x2_trans(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(value[0]), "=r"(value[1])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_ldmatrix_x4_trans(void const* const smem_ptr, void* const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t* value = reinterpret_cast<int32_t*>(local_ptr);
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
: "r"(smem_int_ptr));
}
TL_DEVICE void ptx_stmatrix_x1(void const* const smem_ptr, const int32_t& value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr),
"r"(value0));
}
TL_DEVICE void ptx_stmatrix_x2(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1));
}
TL_DEVICE void ptx_stmatrix_x4(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1, const int32_t& value2,
const int32_t& value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3));
}
TL_DEVICE void ptx_stmatrix_x1_trans(void const* const smem_ptr, const int32_t& value0) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(smem_int_ptr),
"r"(value0));
}
TL_DEVICE void ptx_stmatrix_x2_trans(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(smem_int_ptr),
"r"(value0), "r"(value1));
}
TL_DEVICE void ptx_stmatrix_x4_trans(void const* const smem_ptr, const int32_t& value0,
const int32_t& value1, const int32_t& value2,
const int32_t& value3) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"r"(
smem_int_ptr),
"r"(value0), "r"(value1), "r"(value2), "r"(value3));
}
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
struct SumOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return x + y;
}
};
struct MaxOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return cutlass::fast_max(x, y);
}
};
struct MinOp {
template <typename T>
TL_DEVICE T operator()(T const& x, T const& y) {
return cutlass::fast_min(x, y);
}
};
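// Recursive butterfly all-reduce over `threads` participating threads, stopping
// once the exchange width reaches `scale`. Exchanges whose XOR distance is at
// least a warp (32 threads) go through the shared-memory buffer `red_buf`;
// narrower ones use __shfl_xor_sync. Illustrative use (hypothetical buffer,
// block of 128 threads):
//
//   __shared__ float red_buf[128];
//   float total = AllReduce<SumOp, 128, 1>::run(partial, red_buf);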
template <class Reducer, int threads, int scale>
struct AllReduce {
static_assert(threads == 1024 or threads == 512 or threads == 256 or threads == 128 or
threads == 64 or threads == 32 or threads == 16 or threads == 8 or threads == 4 or
threads == 2);
static_assert(threads % scale == 0);
template <typename T>
static TL_DEVICE T run(T x, T* red_buf = nullptr) {
constexpr int offset = threads / 2;
if constexpr (offset >= 32) {
__syncthreads();
// asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(256));
red_buf[threadIdx.x] = x;
__syncthreads();
// asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(256));
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else {
x = Reducer()(x, T(__shfl_xor_sync(uint32_t(-1), x, offset)));
}
if constexpr (offset == scale) {
return x;
} else {
return AllReduce<Reducer, offset, scale>::run(x, red_buf);
}
}
};
} // namespace tl
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "common.h"
namespace tl {
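// Threadblock rasterization (swizzling): blocks are renumbered into panels of
// `panel_width` rows (or columns) traversed in a serpentine order, so that
// consecutively scheduled blocks touch nearby tiles and reuse L2. The returned
// dim3 is the remapped (x, y, z) block coordinate to use in place of blockIdx.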
template <int panel_width>
TL_DEVICE dim3 rasterization2DRow() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x;
const unsigned int col_idx =
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
template <int panel_width>
TL_DEVICE dim3 rasterization2DColumn() {
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y;
const unsigned int row_idx =
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}
} // namespace tl