"docs/archive_en_US/Tutorial/SetupNniDeveloperEnvironment.md" did not exist on "962d9aee04fbf7af04047d917266b1b40cc6a31a"
Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* with TVM_IMPORT_PYTHON_PATH to handle own tvm build python import

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_tile_op.cc
* \brief Lower the tile op for further codegen.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "loop_partition.h"
namespace tvm {
namespace tl {
using namespace tir;
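// Rebuild a buffer so that its shape matches the output shape of the given
// layout: fragment buffers ("local.fragment") are demoted to plain "local"
// storage, and non-global buffers receive a fresh data Var.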
static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
Type new_type;
// convert fragments to normal local buffer
if (ptr_type->storage_scope == "local.fragment") {
new_type = PointerType(ptr_type->element_type, "local");
} else {
new_type = buffer->data->type_annotation;
}
Var new_var;
if (ptr_type->storage_scope == "global") {
new_var = buffer->data;
} else {
new_var = Var(buffer->data->name_hint, new_type);
}
return Buffer(new_var, buffer->dtype, layout->OutputShape(), {}, buffer->elem_offset,
buffer->name, buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
}
class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
LowerTileOpPass substituter(&analyzer);
// Trace the buffer map for tvm_access_ptr
substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
for (const auto& [_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
substituter.target_ = target.value();
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const BlockNode* op) final {
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : op->alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Map<Var, Layout> vmap;
if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.at(attr::kLayoutMap).as<Map<Buffer, Layout>>().value();
for (auto [buffer, layout] : layout_map) {
buffer_remap_.Set(buffer, makeBufferWithLayout(buffer, layout));
layout_map_.Set(buffer, layout);
}
}
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
auto buffer = block->alloc_buffers[i];
if (buffer_remap_.count(buffer)) {
block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
}
}
for (const auto& buffer : workspaces_) block_ptr->alloc_buffers.push_back(buffer);
workspaces_.clear();
block_ptr->annotations.erase(attr::kLayoutMap);
return block;
}
int CheckAndGetBufferRowSize(Buffer buffer) {
CHECK(buffer->shape.size() >= 2)
<< "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape
<< " should be at least 2";
auto dim = buffer->shape.size();
auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
return buffer_row_size;
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt,
DataType dtype = DataType::Int(32)) {
// The second argument of a T.tvm_access_ptr call is the offset; we set it to 0
// and accumulate it into smem_offset.
CHECK(access_ptr->IsInstance<CallNode>())
<< "Invalid access ptr for permuted layout: " << access_ptr;
auto access_ptr_call = Downcast<Call>(access_ptr);
if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
} else if (access_ptr_call->op.same_as(builtin::address_of())) {
BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
Array<PrimExpr> indices = load->indices;
Array<PrimExpr> shape = load->buffer->shape;
CHECK_EQ(indices.size(), shape.size())
<< "Indices size and shape size must match for general N-dimensional buffer "
<< "but got indices size: " << indices.size() << " and shape size: " << shape.size();
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
elem_offset += indices[i] * stride;
stride *= shape[i];
}
PrimExpr smem_offset = elem_offset + (offset.defined() ? offset.value() : 0);
auto new_buffer = buffer_remap_[load->buffer];
auto buffer_map_iter = buffer_map_.find(Downcast<Var>(load->buffer->data));
CHECK(buffer_map_iter != buffer_map_.end())
<< "The buffer corresponding to data Var " << access_ptr_call->args[0] << " is not found";
int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
(void)buffer_row_size;
// Unflatten the offset into multi-dimensional indices, remap them through the
// layout, and flatten the result back.
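// Worked example (hypothetical shapes): for a shared buffer of shape [64, 32],
// a flat smem_offset of 70 unflattens to indices [2, 6]; those indices are
// mapped through layout->Forward({2, 6}), and the result is flattened and
// unflattened again to form the new BufferLoad indices.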
Array<PrimExpr> multi_dim_indices;
PrimExpr remaining_offset = smem_offset;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
multi_dim_indices.insert(multi_dim_indices.begin(), floormod(remaining_offset, shape[i]));
remaining_offset = floordiv(remaining_offset, shape[i]);
}
auto forward_indices = layout_map_[load->buffer]->Forward(multi_dim_indices);
PrimExpr new_offset = 0;
PrimExpr stride_offset = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
new_offset += forward_indices[i] * stride_offset;
stride_offset *= shape[i];
}
new_offset = analyzer_->Simplify(new_offset);
Array<PrimExpr> new_indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i]));
new_offset = floordiv(new_offset, shape[i]);
}
auto new_access_ptr = access_ptr_call.CopyOnWrite();
new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
} else {
LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
}
return access_ptr_call;
}
PrimExpr VisitExpr_(const tir::CallNode* op) final {
if (!op->op.same_as(builtin::ptx_ldmatrix()) && !op->op.same_as(builtin::mma_store())) {
return Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
} else {
is_ptx_ = true;
}
// Rewrite from/to shared or shared.dyn to/from local
auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (call->op.same_as(builtin::ptx_ldmatrix())) {
// form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
// or T.address_of(buffer, offset)
auto access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
Call address_of_call = Downcast<Call>(access_ptr);
if (!address_of_call->op.same_as(builtin::address_of())) {
LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
}
BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);
if (buffer_remap_.count(load->buffer)) {
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
}
} else if (call->op.same_as(builtin::mma_store())) {
// We now store the result directly into the Buffer instead of calling mma_store.
auto access_ptr = call->args[2];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
} else {
LOG(FATAL) << "Invalid call node: " << call;
}
is_ptx_ = false;
return call;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (is_ptx_) {
return load;
}
if (buffer_remap_.count(load->buffer)) {
auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
auto new_buffer = buffer_remap_[load->buffer];
return BufferLoad(new_buffer, new_indices);
}
return load;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
auto new_buffer = buffer_remap_[store->buffer];
return BufferStore(new_buffer, store->value, new_indices);
}
return store;
}
PrimExpr VisitExpr_(const VarNode* op) final {
auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (buffer_data_to_buffer_.count(var)) {
auto buffer = buffer_data_to_buffer_[var];
if (buffer_remap_.count(buffer)) return buffer_remap_[buffer]->data;
}
return var;
}
Stmt VisitStmt_(const EvaluateNode* op) final {
const CallNode* call = op->value.as<CallNode>();
// Do not analyze calls to global functions.
if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
if (tile_op == nullptr) return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
auto workspace = decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
workspaces_.push_back(workspace);
return workspace.access_ptr(2); // write
};
auto lowered = tile_op->Lower(
LowerArgs{target_, thread_block_size_, thread_var_, callback, layout_map_, buffer_remap_},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
if (iv->thread_tag == "threadIdx.x") {
thread_var_ = iv->var;
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_block_size_ = iv->dom->extent.as<IntImmNode>()->value;
}
}
return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
}
Target target_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Layout> layout_map_;
Map<Buffer, Buffer> buffer_remap_;
Var thread_var_;
size_t thread_block_size_ = 0;
Array<Buffer> workspaces_;
// For PTX nodes, the buffer and indices are remapped via the CallNode
// rather than through the BufferLoad node.
bool is_ptx_{false};
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
};
namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerTileOp() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerTileOpPass::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
} // namespace transform
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \brief Warp-specialized pipeline for CUDA GPUs (sm90+)
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
enum class Role { kConsumer, kProducer, kBoth };
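// A statement is a producer if it only copies data in from global memory (TMA
// bulk copies or SIMT copies into shared memory); everything else is a consumer.
// Compound statements whose children disagree are marked kBoth.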
class WarpSpecializedRoleMarker_ : public StmtVisitor {
public:
WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const {
auto it = map_.find(stmt);
ICHECK(it != map_.end());
return it->second;
}
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
}
SetRole(op, role);
}
void VisitStmt_(const BufferStoreNode* op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) {
SetRole(op, Role::kConsumer);
return;
}
// Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
Role role = Role::kProducer;
for (auto read : reads) {
if (read->buffer.scope() != "global") {
role = Role::kConsumer;
break;
}
}
if (role == Role::kProducer) has_simt_copy_ = true;
SetRole(op, role);
}
void VisitStmt_(const SeqStmtNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) {
if (role != GetRole(stmt)) {
role = Role::kBoth;
break;
}
}
SetRole(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth;
}
SetRole(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; }
private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
};
class MultiVersionBufferRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc& f) {
auto rewriter = MultiVersionBufferRewriter();
rewriter.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : rewriter.buffer_lca_) {
Var buffer_var = buffer->data;
rewriter.buffer_data_to_buffer_.Set(buffer_var, buffer);
}
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
MultiVersionBufferRewriter() = default;
Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt, Array<Buffer> scoped_buffers) {
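// A scoped buffer needs multi-versioning only if it is written by a producer
// stage and read by a consumer stage within the same pipelined loop body.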
std::vector<Role> roles;
Array<Array<BufferRegion>> reads, writes;
auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
for (auto stmt : seq_stmt) {
marker(stmt);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
reads.push_back(std::move(access[0]));
writes.push_back(std::move(access[1]));
roles.push_back(marker.GetRole(stmt));
}
std::unordered_set<const BufferNode*> consumer_used, producer_used;
for (size_t i = 0; i < seq_stmt.size(); i++) {
if (roles[i] == Role::kProducer) {
for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get());
} else {
for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get());
}
}
Array<Buffer> versioned_buffers;
for (Buffer buffer : scoped_buffers) {
if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) {
versioned_buffers.push_back(buffer);
}
}
return versioned_buffers;
}
static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
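// Prepend a version dimension so each pipeline stage owns its own slice; for
// example, a [128, 64] shared buffer with num_versions = 2 becomes [2, 128, 64].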
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
}
return Buffer(new_buffer);
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize block_realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
Block block = block_realize->block;
Array<Buffer> alloc_buffers;
for (auto buffer : block->alloc_buffers) {
if (buffer_remap_.count(buffer)) {
Buffer new_buffer = buffer_remap_[buffer];
alloc_buffers.push_back(new_buffer);
} else {
alloc_buffers.push_back(buffer);
}
}
block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
Stmt VisitStmt_(const ForNode* op) final {
auto num_stages_anno = op->annotations.Get("num_stages");
if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(op);
ICHECK(num_stages_anno.as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
const SeqStmtNode* pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< op->body->GetTypeKey();
Array<Buffer> scoped_buffers = {};
for (auto [buffer, stmt] : buffer_lca_) {
if (stmt.defined() && stmt.value().get() == op) scoped_buffers.push_back(buffer);
}
Array<Buffer> versioned_buffers = GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
for (auto buffer : versioned_buffers) {
Var buffer_var = buffer->data;
Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages);
buffer_remap_.Set(buffer, new_buffer);
}
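// Iteration i of the loop accesses version (i - min) % num_stages of each
// multi-versioned buffer.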
version_index_ = FloorMod(op->loop_var - op->min, num_stages);
auto for_node = StmtExprMutator::VisitStmt_(op);
return for_node;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_remap_.find(load->buffer);
if (it == buffer_remap_.end()) {
return std::move(load);
}
const Buffer& new_buffer = (*it).second;
auto* n = load.CopyOnWrite();
n->buffer = new_buffer;
n->indices.insert(n->indices.begin(), version_index_);
return std::move(load);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_remap_.find(store->buffer);
if (it == buffer_remap_.end()) {
return std::move(store);
}
const Buffer& new_buffer = (*it).second;
auto* n = store.CopyOnWrite();
n->buffer = new_buffer;
n->indices.insert(n->indices.begin(), version_index_);
return std::move(store);
}
PrimExpr VisitExpr_(const CallNode* op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(builtin::tvm_access_ptr())) {
return RewriteBufferAccess(call, {1});
}
return call;
}
PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr>& input) {
return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
auto buffer_var = Downcast<Var>(call->args[i]);
if (!buffer_data_to_buffer_.count(buffer_var)) continue;
const Buffer& buffer = buffer_data_to_buffer_[buffer_var];
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer& new_buffer = (*it).second;
const PrimExpr& old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
} else {
offset = new_buffer->strides[0];
}
PrimExpr new_index = old_index + version_index_ * offset;
new_args.Set(i + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
}
PrimExpr version_index_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Optional<Stmt>> buffer_lca_;
Map<Buffer, Buffer> buffer_remap_;
};
using namespace tir::transform;
tvm::transform::Pass MultiVersionBuffer() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return MultiVersionBufferRewriter::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MultiVersionBuffer")
.set_body_typed(MultiVersionBuffer);
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file pipeline_planning.cc
* \brief Plan the software pipeline
*/
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
namespace tvm {
namespace tl {
using namespace tir;
namespace {
/*!
* \brief Check whether two regions have intersections.
* \param region1 The first region.
* \param region2 The second region.
* \return Whether region1 and region2 have intersections.
*/
bool MayConflict(Region region1, Region region2) {
ICHECK(region1.size() == region2.size());
for (size_t i = 0; i < region1.size(); i++) {
Range dim1 = region1[i];
Range dim2 = region2[i];
auto int_set1 = arith::IntSet::FromRange(dim1);
auto int_set2 = arith::IntSet::FromRange(dim2);
if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
return false;
}
}
return true;
}
} // namespace
class PipelinePlanner : public StmtExprMutator {
public:
static Stmt Substitute(const PrimFunc& f) {
PipelinePlanner substituter;
for (const auto& [_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "Pipeline_Planning: Require the target attribute";
substituter.target_ = target.value();
return substituter.VisitStmt(f->body);
}
private:
PipelinePlanner() = default;
struct PipelineStageInfo {
Array<BufferRegion> reads, writes;
int original_order;
int order = -1, stage = -1;
bool copy_stage = false;
int last_use_stage = -1;
};
PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
PipelineStageInfo pinfo;
pinfo.reads = std::move(access[0]);
pinfo.writes = std::move(access[1]);
pinfo.original_order = idx;
// A copy stage should have exactly one read region and one write region.
if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
}
return std::move(pinfo);
}
Stmt VisitStmt_(const ForNode* loop) final {
auto num_stages_anno = loop->annotations.Get("num_stages");
if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(loop);
int num_stages = num_stages_anno.as<IntImmNode>()->value;
Stmt pipeline_body{nullptr};
if (const auto* realize = loop->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
} else {
pipeline_body = loop->body;
}
const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< loop->body->GetTypeKey();
CHECK(num_stages >= 1);
CHECK(loop->kind == ForKind::kSerial);
std::vector<PipelineStageInfo> pipeline_stage_infos;
for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i);
pipeline_stage_infos.push_back(std::move(pinfo));
}
// Analyze the use-def chain.
for (auto& pinfo : pipeline_stage_infos) {
for (int i = pinfo.original_order + 1; i < static_cast<int>(pipeline_body_seq->size()); i++) {
if (!pinfo.copy_stage) continue;
for (const BufferRegion& read : pipeline_stage_infos[i].reads) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), [&](const BufferRegion& r) {
return r->buffer == read->buffer && MayConflict(r->region, read->region);
}) != pinfo.writes.end()) {
pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
}
}
for (const BufferRegion& write : pipeline_stage_infos[i].writes) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), [&](const BufferRegion& r) {
return r->buffer == write->buffer && MayConflict(r->region, write->region);
}) != pinfo.writes.end()) {
CHECK(false) << "Can't handle multiple write on overlap buffer region in the pipeline "
"planning pass: "
<< pipeline_body_seq->seq[pinfo.original_order];
}
}
}
}
// Making stages and orders
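// Non-copy statements (and copies with no later consumer) are placed at stage
// `num_stages`; each copy whose last consumer is the current statement is
// ordered immediately after it and assigned stage 0, so copies are issued
// ahead of the statements that consume them.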
int order_idx = 0;
for (auto& pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage && pinfo.last_use_stage != -1) continue;
pinfo.order = order_idx++;
pinfo.stage = num_stages;
for (auto& pinfo_1 : pipeline_stage_infos) {
if (pinfo_1.copy_stage && pinfo_1.last_use_stage == pinfo.original_order) {
pinfo_1.order = order_idx++;
pinfo_1.stage = 0;
}
}
}
ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
    << "The number of assigned orders should equal the number of pipeline statements. "
    << "Got " << order_idx << " orders and " << pipeline_stage_infos.size() << " statements.";
// If all copies come at the end of the order, we can rotate them to the beginning
// of the order and shrink the stage offset of the non-copy statements by 1.
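// Worked example (hypothetical): with three statements ordered
// {compute: 0, copyA: 1, copyB: 2} and copy_stage_at_end == 2, the rotation
// below yields {copyA: 0, copyB: 1, compute: 2} and decrements the non-copy
// stage by one.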
int copy_stage_at_end = [&]() {
int copy_stage_cnt = 0;
int copy_order_min = pipeline_stage_infos.size();
int non_copy_order_max = 0;
for (auto& pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage) {
copy_stage_cnt++;
copy_order_min = std::min(copy_order_min, pinfo.order);
} else {
non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
}
}
if (copy_order_min > non_copy_order_max) return copy_stage_cnt;
return -1;
}();
if (copy_stage_at_end > 0 && num_stages >= 2) {
for (auto& pinfo : pipeline_stage_infos) { // move copy to the beginning
pinfo.order = (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
if (!pinfo.copy_stage) pinfo.stage--;
}
}
// Finally, make the pipeline annotation
Map<String, ObjectRef> annotations;
for (const auto& [key, value] : loop->annotations) {
if (key != "num_stages") {
annotations.Set(key, value);
}
}
std::vector<Integer> orders, stages;
orders.reserve(pipeline_stage_infos.size());
stages.reserve(pipeline_stage_infos.size());
for (auto& pinfo : pipeline_stage_infos) {
orders.push_back(pinfo.order);
stages.push_back(pinfo.stage);
}
annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
if (TargetHasAsyncCopy(target_))
annotations.Set(tir::attr::software_pipeline_async_stages, Array<Integer>{0});
return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
loop->thread_binding, annotations);
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const auto& buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
for (const auto& buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
return std::move(block);
}
Map<Var, Buffer> buffer_data_to_buffer_;
Target target_;
};
tvm::transform::Pass PipelinePlanning() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = PipelinePlanner::Substitute(f);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning").set_body_typed(PipelinePlanning);
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file simplify.cc
* \brief Remove useless parameters of TL PrimFunc.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/transform.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tir/analysis/control_flow_graph.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace arith;
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
bool transitively_prove_inequalities;
bool propagate_knowns_to_prove_conditional;
bool propagate_knowns_to_simplify_expressions;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;
TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
TVM_ATTR_FIELD(transitively_prove_inequalities)
.describe(
"If true, simplify conditionals with transitive combinations of scoped constraints")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
.describe(
"If true, known buffer values are propagated and used to statically prove conditionals")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
.describe(
"If true, known buffer values are propagated and used to replace BufferLoad wherever "
"possible")
.set_default(false);
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
.describe("If true, simplify conditionals into an AND of ORs")
.set_default(false);
TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
.describe(
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.set_default(false);
}
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
if (transitively_prove_inequalities) {
flags =
RewriteSimplifier::Extension(flags | RewriteSimplifier::kTransitivelyProveInequalities);
}
if (convert_boolean_to_and_of_ors) {
flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
}
if (apply_constraints_to_boolean_branches) {
flags = RewriteSimplifier::Extension(flags |
RewriteSimplifier::kApplyConstraintsToBooleanBranches);
}
return flags;
}
};
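// Walk the function body and collect every buffer from the buffer_map whose data
// Var or region is referenced, so that unused buffer parameters can later be
// dropped from the PrimFunc signature.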
std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;
Visitor(PrimFunc func) : func(func) {}
void VisitExpr_(const CallNode* op) override {
for (const auto& arg: op->args) {
for (const auto& it: func->buffer_map) {
if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const BufferLoadNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BlockNode* op) override {
for (const auto& buffer: op->alloc_buffers) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
for (const auto& buffer: op->reads) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
for (const auto& buffer: op->writes) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitBuffer(const Buffer& buf) {
// Collect buffers that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);
for (const auto& buffer : usage.buffer_use_count_) {
if (buffer.second >= 1) {
used_in_buffer_def_.insert(buffer.first);
}
}
for (const auto& buffer : usage.undefined_buffers_) {
used_in_buffer_def_.insert(buffer.get());
}
}
PrimFunc func;
std::unordered_set<const BufferNode*> used_in_buffer_def_;
};
Visitor visitor(func);
visitor(func->body);
return visitor.used_in_buffer_def_;
}
/* \brief Utility function to collect vars that should be retained. Used only in the LetStmt mutator.
*/
std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt& stmt) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;
void VisitExpr_(const BufferLoadNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VisitBuffer(const Buffer& buf) {
// Collect variables that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);
// Track for use in LetStmtNode mutator
for (const auto& var : usage.undefined_) {
used_in_buffer_def_.insert(var.get());
}
}
std::unordered_set<const VarNode*> used_in_buffer_def_;
};
Visitor visitor;
visitor(stmt);
return visitor.used_in_buffer_def_;
}
class SimplifyConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
};
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
Optional<SimplifyConfig> config_opt = NullOpt) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
if (config->propagate_knowns_to_prove_conditional ||
config->propagate_knowns_to_simplify_expressions) {
touch_pattern = ControlFlowGraph(func->body);
}
std::unordered_set<const VarNode*> used_in_buffer_def =
CollectVarsUsedInBufferDefinition(func->body);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
simplifier.MarkBufferMapShapes(func);
func.CopyOnWrite()->body = simplifier(func->body);
// Begin removing unused vars and buffers: first collect the buffers that are used.
simplifier.used_buffers_ = CollectUsedBuffers(func);
bool param_updated = false;
Array<Var> new_params;
Map<Var, Buffer> new_buffer_map;
// Check whether each buffer is used
for (const auto& var: func->params) {
if (func->buffer_map.find(var) != func->buffer_map.end()) {
if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
} else {
// Non-buffer (scalar) parameters are always retained.
new_params.push_back(var);
}
}
// return func;
if (param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, new_buffer_map, func->attrs, func->span);
} else {
return func;
}
}
private:
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode*> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer),
config_(config),
touch_pattern_(touch_pattern),
used_in_buffer_def_(used_in_buffer_def) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt;
using Parent::VisitStmt_;
PrimExpr VisitExpr(const PrimExpr& expr) final {
if (config_->propagate_knowns_to_simplify_expressions) {
return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_);
} else {
return analyzer_->Simplify(expr);
}
}
Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
Stmt VisitStmt(const Stmt& stmt) override {
Optional<Stmt> cache = this->current_stmt_;
this->current_stmt_ = stmt;
Stmt output = Parent::VisitStmt(stmt);
this->current_stmt_ = std::move(cache);
return output;
}
Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
return Parent::VisitStmt_(op);
}
bool CanInlineLetStmt(const LetStmtNode* op) {
if (is_const_number(op->value)) return true;
if (op->value.as<VarNode>()) return true;
// Unlike Let expressions, inlining here does not cause deep expression explosion,
// so attempt to inline as much as possible when the value has an integer type
// (it may be used as an index).
if (!op->value.dtype().is_int()) return false;
return SideEffect(op->value) <= CallEffectKind::kPure;
}
Stmt VisitStmt_(const LetStmtNode* op) override {
PrimExpr value = this->VisitExpr(op->value);
bool can_inline = CanInlineLetStmt(op);
if (can_inline) {
// It is usually fine to discard the let binding because the
// call to simplify will always inline the var.
//
// The exception is when the variable is used in a Buffer's
// definition, as these are not updated by the simplification.
// After DeclBuffer is required prior to use of a buffer,
// simplifying can update the buffer definition as well. The
// buffer can only be updated at its point of definition,
// because the points of use may occur within contexts that
// allow for additional simplifications (e.g. a buffer of shape
// [i,j] whose first use occurs within "if i==1" should not have
// its shape simplified to [1,j]).
analyzer_->Bind(op->var, value);
} else if (SideEffect(op->value) <= CallEffectKind::kPure) {
// Even if we aren't replacing all occurrences, they may be
// necessary for proving conditional statements.
non_inlined_bindings_.Set(op->var, value);
}
Stmt body = this->VisitStmt(op->body);
// TODO(Lunderberg): Update the Buffer object as part of
// DeclBuffer updates, which will first require
// https://github.com/apache/tvm/pull/14778.
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
if (can_inline && !used_in_buffer_def) {
return body;
} else if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt VisitStmt_(const IfThenElseNode* op) override {
if (Optional<Bool> cond = ProveCondition(op->condition)) {
if (cond.value()->value) {
return this->VisitStmt(op->then_case);
} else if (op->else_case) {
return this->VisitStmt(op->else_case.value());
} else {
return Evaluate(0);
}
} else {
return Parent::VisitStmt_(op);
}
}
PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::if_then_else())) {
if (Optional<Bool> cond = ProveCondition(op->args[0])) {
if (cond.value()->value) {
return this->VisitExpr(op->args[1]);
} else {
return this->VisitExpr(op->args[2]);
}
}
}
return Parent::VisitExpr_(op);
}
PrimExpr VisitExpr_(const VarNode* op) override {
used_vars_.insert(op);
return Parent::VisitExpr_(op);
}
PrimExpr VisitExpr_(const BufferLoadNode* op) override {
auto buffer = op->buffer.get();
if (used_buffers_.find(buffer) == used_buffers_.end()) {
used_buffers_.insert(buffer);
}
return Parent::VisitExpr_(op);
}
// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) override {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
ArrayDeepEqual(load->indices, store->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
return Evaluate(0);
}
}
auto buffer = op->buffer.get();
if (used_buffers_.find(buffer) == used_buffers_.end()) {
used_buffers_.insert(buffer);
}
return std::move(store);
}
private:
bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (size_t i = 0; i < lhs.size(); i++) {
if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
return false;
}
}
return true;
}
/* \brief Internal utility for checking conditionals
*
* Uses more aggressive optimization, such as performing additional
* inlining and tracking known buffer values.
*/
Optional<Bool> ProveCondition(PrimExpr condition) const {
condition = Substitute(condition, non_inlined_bindings_);
if (config_->propagate_knowns_to_prove_conditional) {
ICHECK(touch_pattern_.has_value());
condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_);
} else {
condition = analyzer_->Simplify(condition);
}
if (const int64_t* as_int = as_const_int(condition)) {
return Bool(*as_int);
} else {
return NullOpt;
}
}
SimplifyConfig config_;
std::optional<ControlFlowGraph> touch_pattern_;
Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
std::unordered_set<const VarNode*> used_in_buffer_def_;
std::unordered_set<const VarNode*> used_vars_;
std::unordered_set<const BufferNode*> used_buffers_;
};
using namespace tir::transform;
tvm::transform::Pass Simplify() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
return StmtSimplifier::Apply(f, &analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "../op/builtin.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/storage_access.h"
namespace tvm {
namespace tl {
using namespace tir;
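// Plans where thread synchronizations are required within a kernel. When a
// conflicting access occurs inside a warp-specialized partition (marked by the
// "kWarpSpecializationScope" attribute), the planner records the number of
// participating threads so that a partial barrier can be emitted instead of a
// full block-wide one.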
class ThreadPartialSyncPlanner : public StorageAccessVisitor {
public:
explicit ThreadPartialSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}
// The syncs inserted before each statement
std::unordered_set<const Object*> syncs_inserted_;
std::unordered_map<const Object*, int> partial_syncs_inserted_;
protected:
bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
// Redirect all "shared.dyn" buffer access to the same buffer var
// so that the accesses can be planned together.
Var shared_dyn_buf;
for (StmtEntry& entry : seq) {
for (AccessEntry& access : entry.access) {
if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" &&
access.buffer.defined()) {
if (!shared_dyn_buf.defined()) {
shared_dyn_buf = access.buffer;
} else {
access.buffer = shared_dyn_buf;
}
}
}
}
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
// If this is a loop, iterate over the sequence a second time to account for
// loop-carried effects; this is a simulation-based approach to finding dependencies.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
// If a sync has been inserted, remove the now-irrelevant accesses.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
// Add the read/write of current statement
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
reads.push_back(acc);
} else if (acc.type == kWrite) {
writes.push_back(acc);
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
}
}
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
if (reads.empty() && writes.empty()) break;
bool sync_before_stmt = false;
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
break;
}
}
}
// return the exposed entries, remove unnecessary ones.
int sync_count = 0;
// head are before first sync, tail are after last sync
std::vector<AccessEntry> head, tail;
AccessEntry esync;
esync.threads = this->env_threads();
esync.type = kSync;
esync.scope = sync_scope_;
for (const StmtEntry& s : seq) {
if (syncs_inserted_.count(s.stmt)) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kSync) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
} else {
if (sync_count != 0) {
tail.push_back(acc);
} else {
head.push_back(acc);
}
}
}
}
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// clear double buffer flag after a loop is finished.
for (AccessEntry& e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
// find conflicting entry in vec.
bool FindConflict(const std::vector<AccessEntry>& prev, const AccessEntry& curr,
bool loop_carry) {
for (const AccessEntry& x : prev) {
if (FindConflict(x, curr, loop_carry)) {
return true;
}
}
return false;
}
bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) {
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
}
// Assumes no race between threads
// Same index value means no conflicts
// TODO(tqchen) more standard set based testing.
bool has_same_index = true;
// Even if access has the same index, those indices need to
// depend on the innermost thread id to avoid race condition
bool depends_on_thread_index = true;
const VarNode* thread_index_var = nullptr;
if (!curr.threads.empty()) {
thread_index_var = curr.threads.back()->var.get();
}
for (size_t i = 0; i < prev.touched.size(); i++) {
const auto& prev_intset = prev.touched[i];
const auto& curr_intset = curr.touched[i];
if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
PrimExpr prev_index = prev_intset.PointValue();
PrimExpr curr_index = curr_intset.PointValue();
has_same_index = ExprDeepEqual()(prev_index, curr_index);
if (thread_index_var != nullptr) {
auto f_uses_thread_index = [=](const tvm::tir::VarNode* parameter) {
return parameter == thread_index_var;
};
depends_on_thread_index = depends_on_thread_index &&
UsesVar(curr_index, f_uses_thread_index) &&
UsesVar(prev_index, f_uses_thread_index);
}
} else {
has_same_index = false;
}
if (!(has_same_index && depends_on_thread_index)) {
break;
}
}
if (has_same_index && depends_on_thread_index) {
return false;
}
// If this is a read into a double buffer that was previously
// swapped out, then it doesn't conflict.
if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
return false;
}
// If nothing else allows sharing the same buffer, then they are
// in conflict.
return true;
}
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "kWarpSpecializationScope") {
IfThenElse body = Downcast<IfThenElse>(op->body);
auto partitions = Downcast<Array<IntImm>>(op->node);
ICHECK(partitions.size() == 2);
scope_.push_back(std::vector<StmtEntry>());
num_partial_threads_ = partitions[0];
this->VisitStmt(body->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
num_partial_threads_ = partitions[1];
scope_.push_back(std::vector<StmtEntry>());
VisitStmt(body->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
} else {
StorageAccessVisitor::VisitStmt_(op);
}
}
void insert_syncs(const Object* obj) {
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
if (syncs_inserted_.count(obj)) return;
if (num_partial_threads_.defined()) {
syncs_inserted_.insert(obj);
partial_syncs_inserted_[obj] = static_cast<int>(num_partial_threads_.value()->value);
} else {
syncs_inserted_.insert(obj);
}
}
private:
Optional<IntImm> num_partial_threads_;
// synchronization scope
StorageScope sync_scope_;
};
// There are cases where a necessary syncthreads is not inserted by ThreadPartialSyncPlanner.
// For example, a syncthreads is needed after async_wait_queue in the second loop below,
// but since the planner is not aware of the asynchronous semantics, it cannot tell
// that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
// async_commit_queue(0):
// async_scope:
// shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline Epilogue
// for i in range(3):
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
class ThreadPartialSyncInserter : public StmtExprMutator {
public:
ThreadPartialSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs,
std::unordered_map<const Object*, int> partial_syncs)
: sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}
Stmt VisitStmt(const Stmt& stmt) final {
if (syncs_.size() == 0) return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (partial_syncs_.count(stmt.get())) {
auto iter = partial_syncs_.find(stmt.get());
ICHECK(sync_scope_.rank == StorageRank::kShared);
barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(), {iter->second}));
} else {
return StmtExprMutator::VisitStmt(stmt);
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = SeqStmt({barrier, ret});
return ret;
} else {
return StmtExprMutator::VisitStmt(stmt);
}
}
private:
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
const std::unordered_map<const Object*, int>& partial_syncs_;
};
Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
ThreadPartialSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(std::move(stmt));
}
using namespace tir::transform;
namespace transform {
Pass ThreadPartialSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = tl::ThreadPartialSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync").set_body_typed(ThreadPartialSync);
} // namespace transform
} // namespace tir
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
 * \brief Warp-specialized pipeline for CUDA GPUs (sm90+)
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
enum class Role { kConsumer, kProducer, kBoth };
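// Classifies every statement as producer (copies data from global into shared
// memory via a TMA bulk copy or a SIMT copy), consumer, or both, so the
// warp-specialized rewriter can split the body into the two branches.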
class WarpSpecializedRoleMarker : public StmtVisitor {
public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const {
auto it = map_.find(stmt);
ICHECK(it != map_.end());
return it->second;
}
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
}
SetRole(op, role);
}
void VisitStmt_(const BufferStoreNode* op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) {
SetRole(op, Role::kConsumer);
return;
}
// Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
Role role = Role::kProducer;
for (auto read : reads) {
if (read->buffer.scope() != "global") {
role = Role::kConsumer;
break;
}
}
if (role == Role::kProducer) has_simt_copy_ = true;
SetRole(op, role);
}
void VisitStmt_(const SeqStmtNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) {
if (role != GetRole(stmt)) {
role = Role::kBoth;
break;
}
}
SetRole(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth;
}
SetRole(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; }
private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
};
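// Helpers that build the mbarrier intrinsic calls (expect-tx, arrive,
// cp.async barrier, parity wait) used by the emitters below.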
static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
}
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)});
return Evaluate(call);
}
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
auto call =
Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)});
return Evaluate(call);
}
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity});
return Evaluate(call);
}
// static bool isGemm(Stmt stmt) {
// bool is_gemm = false;
// if (stmt.as<EvaluateNode>()) {
// auto call = Downcast<Evaluate>(stmt)->value.as<CallNode>();
// if (call && call->op.same_as(Op::Get("tir.call_extern"))) {
// if (call->args[0].as<StringImmNode>()) {
// std::string name = Downcast<StringImm>(call->args[0])->value;
// if (name.find("gemm") != std::string::npos) {
// is_gemm = true;
// }
// }
// }
// }
// return is_gemm;
// }
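// Collects traits of a producer statement: the number of bytes moved by TMA
// bulk copies (scaled by the extents of the enclosing loops) and whether it
// also contains SIMT copies (detected via BufferLoad).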
class ProducerTraitsCollector : public StmtExprVisitor {
public:
ProducerTraitsCollector() { Clear(); }
void Clear() {
bulk_copy_bytes = 0;
loop_extents = 1;
has_simt_copy = false;
}
void Collect(Stmt stmt) { VisitStmt(stmt); }
bool HasSimtCopy() { return has_simt_copy; }
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private:
void VisitExpr_(const CallNode* call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
StmtExprVisitor::VisitExpr_(call);
}
void VisitStmt_(const ForNode* op) final {
    PrimExpr old_loop_extents = loop_extents;
    loop_extents *= op->extent;
    StmtExprVisitor::VisitStmt_(op);
    loop_extents = old_loop_extents;
}
void VisitExpr_(const BufferLoadNode* op) final {
has_simt_copy = true;
StmtExprVisitor::VisitExpr_(op);
}
bool has_simt_copy;
PrimExpr bulk_copy_bytes;
PrimExpr loop_extents;
};
// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
MbarrierRewriter rewriter;
rewriter.producer_barrier_idx_ = barrier_id;
return rewriter(stmt);
}
private:
PrimExpr VisitExpr_(const CallNode* op) final {
auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
call.CopyOnWrite()->args.Set(1, makeGetBarrier(producer_barrier_idx_));
}
return call;
}
PrimExpr producer_barrier_idx_;
};
class ThreadIdxRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
auto rewriter = ThreadIdxRewriter(thread_var, replaced);
return rewriter(stmt);
}
private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
: thread_var_(thread_var), replaced_(replaced) {}
PrimExpr VisitExpr_(const VarNode* var) final {
if (var == thread_var_.get()) {
return replaced_;
} else {
return StmtExprMutator::VisitExpr_(var);
}
}
Var thread_var_;
PrimExpr replaced_;
};
Block MakeGroupBlock(const Stmt& stmt, const Map<String, ObjectRef>& annotations) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt,
/*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations);
return block;
}
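// OpInfo / PipelineInfo mirror the tl_pipeline_group / order / stage
// annotations: each entry groups a set of statement indices and records the
// software-pipeline order and stage assigned to that group.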
struct OpInfo {
int group_size, order, stage;
std::vector<int> group;
};
struct PipelineInfo {
std::vector<OpInfo> op_infos;
PipelineInfo() = default;
PipelineInfo(
Array<Array<Integer>> group_info,
Array<Integer> order_info,
Array<Integer> stage_info
) {
int n = static_cast<int>(group_info.size());
ICHECK(n == static_cast<int>(order_info.size()));
ICHECK(n == static_cast<int>(stage_info.size()));
// int cur_id = 0;
for (int i = 0; i < n; i++) {
OpInfo op_info;
op_info.group_size = group_info[i].size();
for (int j = 0; j < op_info.group_size; j++) {
op_info.group.push_back(group_info[i][j].as<IntImmNode>()->value);
}
op_info.order = order_info[i].as<IntImmNode>()->value;
op_info.stage = stage_info[i].as<IntImmNode>()->value;
op_infos.push_back(op_info);
}
}
PipelineInfo(const PipelineInfo& other) {
for (auto op_info : other.op_infos) {
op_infos.push_back(op_info);
}
}
std::pair<int, int> FindStmt(int stmt_idx) {
for (size_t i = 0; i < op_infos.size(); i++) {
for (size_t j = 0; j < op_infos[i].group.size(); j++) {
if (op_infos[i].group[j] == stmt_idx) {
return std::make_pair(i, j);
}
}
}
return std::make_pair(-1, -1);
}
void UpdateOrder(int order) {
for (int i = 0; i < static_cast<int>(op_infos.size()); i++) {
if (op_infos[i].order >= order && op_infos[i].order > 0) {
op_infos[i].order++;
}
}
}
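  // Split the group containing `stmt_idx` right after that statement and
  // return the order slot reserved for the synchronization placed between the
  // two halves. Example: groups {[0,1,2]} with order {0} split at stmt 1
  // become groups {[0,1],[2]} with orders {0,2}; the returned sync order is 1.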
int SplitOp(int stmt_idx) {
auto pair = FindStmt(stmt_idx);
int op_idx = pair.first;
int inner_idx = pair.second;
ICHECK(op_idx != -1);
ICHECK(inner_idx != -1);
OpInfo half0;
OpInfo half1;
// The order to do sync
int sync_order = op_infos[op_idx].order + 1;
UpdateOrder(sync_order);
half0.group_size = inner_idx + 1;
half0.order = op_infos[op_idx].order;
half0.stage = op_infos[op_idx].stage;
for (int i = 0; i <= inner_idx; i++) {
half0.group.push_back(op_infos[op_idx].group[i]);
}
half1.group_size = op_infos[op_idx].group_size - inner_idx - 1;
half1.order = op_infos[op_idx].order + 2;
half1.stage = op_infos[op_idx].stage;
for (int i = inner_idx + 1; i < op_infos[op_idx].group_size; i++) {
half1.group.push_back(op_infos[op_idx].group[i]);
}
op_infos.erase(op_infos.begin() + op_idx);
if (half0.group_size > 0) {
op_infos.insert(op_infos.begin() + op_idx, half0);
}
if (half1.group_size > 0) {
UpdateOrder(half1.order);
op_infos.insert(op_infos.begin() + op_idx + 1, half1);
}
return sync_order;
}
void PrintPipelineInfo() {
std::cout << "Print op_infos:" << std::endl;
for (size_t i = 0; i < op_infos.size(); i++) {
std::cout << i << " " << op_infos[i].group_size << " " << op_infos[i].order << " " << op_infos[i].stage << std::endl;
}
std::cout << "End of print" << std::endl;
}
};
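// Rewrites a pipelined For according to PipelineInfo: statements in the same
// group are fused into a single block, and the tl_pipeline_group annotation is
// replaced by the software_pipeline_order / software_pipeline_stage
// annotations consumed by the generic software-pipeline pass.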
class GroupOpRewriter : public StmtExprMutator {
public:
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}
private:
Stmt VisitStmt_(const ForNode* op) final {
Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1));
auto original_node = (op->body).as<SeqStmtNode>();
if (!original_node) {
return GetRef<For>(op);
}
Array<Stmt> new_body;
int cur_id = 0;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
if (pipeline_info_.op_infos[i].group_size == 0) continue;
Array<Stmt> block_stmt;
for (int j = 0; j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
// ICHECK(group_info_[i][j].as<IntImmNode>());
// int index = static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
ICHECK(original_node->seq[cur_id].as<BlockNode>());
auto block = original_node->seq[cur_id].as<BlockNode>();
// TODO: handle nested seqstmt
block_stmt.push_back(block->body);
cur_id++;
}
new_body.push_back(
MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
}
Array<Integer> order_anno;
Array<Integer> stage_anno;
for (auto op_info : pipeline_info_.op_infos) {
order_anno.push_back(Integer(op_info.order));
stage_anno.push_back(Integer(op_info.stage));
}
Map<String, ObjectRef> for_annotations = op->annotations;
for_annotations.erase("tl_pipeline_group");
for_annotations.Set("software_pipeline_order", order_anno);
for_annotations.Set("software_pipeline_stage", stage_anno);
    For new_for =
        For(op->loop_var, op->min, op->extent, op->kind,
            new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
            op->thread_binding, for_annotations);
return new_for;
}
PipelineInfo pipeline_info_;
};
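// Emits either the producer or the consumer side of a warp-specialized body.
// Statements belonging to the other role are dropped (replaced by Evaluate(0)),
// and mbarrier waits / arrives are inserted according to the extracted sync
// patterns.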
class WSCodeEmitter : public StmtMutator {
public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer),
marker_(marker),
thread_var_(thread_iv->var) {}
private:
template <typename NodeType>
Stmt FilterByRole(const NodeType* op) {
Role role = marker_.GetRole(op);
if (role == Role::kBoth)
return StmtMutator::VisitStmt_(op);
else if ((role == Role::kProducer) == is_emitting_producer_)
return GetRef<Stmt>(op);
else
return Evaluate(0);
}
// TODO: only need to add block for ops in the loop
Stmt VisitStmt_(const SeqStmtNode* op) final {
bool has_producer = false;
for (auto stmt : op->seq) {
if (marker_.GetRole(stmt) == Role::kProducer) {
has_producer = true;
break;
}
}
bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth;
if (!need_producer_sync) return FilterByRole(op);
auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq);
// std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl;
// }
// std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// }
// std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo();
Array<Stmt> new_body;
Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1));
if (is_emitting_producer_) { // producer case
ProducerTraitsCollector collector;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue;
if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
block_stmt.push_back(seq_transformed[i]);
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
continue;
}
if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity =
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
}
ICHECK(map.release[i] >= 0);
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i];
auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt);
if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = IfThenElse(EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
block_stmt.push_back(expect_tx);
}
block_stmt.push_back(stmt);
        if (collector.HasSimtCopy()) {
block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
}
if (map.release_after[i]) {
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]);
}
}
collector.Clear();
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
}
} else { // consumer case
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue;
if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity =
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
}
block_stmt.push_back(seq_transformed[i]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
if (map.release_after[i]) {
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i];
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]);
}
// Update the pipeline info
// Todo: handle sync
}
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
}
// Filter out the producer stmts
int cur_id = 0;
PipelineInfo new_pipeline_info;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
auto op_info = pipeline_info_.op_infos[i];
bool is_producer = false;
for (int j = 0; j < op_info.group_size; j++) {
if (marker_.GetRole(op->seq[cur_id]) == Role::kProducer) {
is_producer = true;
}
cur_id++;
}
if (is_producer) {
ICHECK(op_info.group_size == 1);
} else {
new_pipeline_info.op_infos.push_back(op_info);
}
}
pipeline_info_ = new_pipeline_info;
}
num_barriers_ += map.patterns.size() * num_stages_;
ICHECK(new_body.size() > 0);
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
}
Stmt VisitStmt_(const ForNode* op) final {
int num_stages = 1;
auto num_stages_anno = op->annotations.Get("num_stages");
if (num_stages_anno.defined()) {
ICHECK(num_stages_anno.as<IntImmNode>());
num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
}
Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array;
Array<Integer> stage_info_array;
auto group_anno = op->annotations.Get("tl_pipeline_group");
if (group_anno.defined()) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
}
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (order_anno.defined()) {
order_info_array = Downcast<Array<Integer>>(order_anno);
}
auto stage_anno = op->annotations.Get("tl_pipeline_stage");
if (stage_anno.defined()) {
stage_info_array = Downcast<Array<Integer>>(stage_anno);
}
PipelineInfo pipeline_info(group_info_array, order_info_array, stage_info_array);
if (pipeline_info.op_infos.size() > 0) {
ICHECK(pipeline_info_.op_infos.size() == 0) << "Nested pipeline not supported.";
}
PrimExpr parity_before = std::move(parity_);
PrimExpr stage_before = std::move(stage_);
int num_stages_before = num_stages_;
PipelineInfo pipeline_info_before = pipeline_info_;
num_stages_ = num_stages;
pipeline_info_ = pipeline_info;
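    // stage_ cycles through the pipeline stages within this loop, while parity_
    // flips every num_stages iterations (folding in the parity of enclosing
    // loops), matching the phase bit expected by the mbarrier parity waits.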
stage_ = FloorMod(op->loop_var - op->min, num_stages);
parity_ =
FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2);
auto result = FilterByRole(op);
Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno.defined() && group_info_array.size() > 0 && !is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result);
grouped_for_node = group_op_rewriter(for_node);
}
parity_ = std::move(parity_before);
stage_ = std::move(stage_before);
num_stages_ = num_stages_before;
pipeline_info_ = pipeline_info_before;
// remove pipeline annotation
    if (result.as<ForNode>()) {
      auto for_node = Downcast<For>(result);
for_node.CopyOnWrite()->annotations.erase("num_stages");
if (is_emitting_producer_ || group_info_array.size() == 0) {
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
}
      if (is_emitting_producer_ || !group_anno.defined() || group_info_array.size() == 0) {
return for_node;
}
return grouped_for_node;
}
return result;
}
Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BlockNode* op) final {
ICHECK(0);
return Stmt();
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
ICHECK(0);
return Stmt();
}
struct SyncPattern {
int release_idx, acquire_idx;
};
struct SyncPatternMap {
std::vector<int> acquire;
std::vector<int> release;
std::vector<bool> release_after;
std::vector<SyncPattern> patterns;
bool is_loop_dependency(int i) {
      // Returns true if the acquire waits on a release from the previous loop iteration.
return patterns[i].release_idx > patterns[i].acquire_idx;
}
};
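  // A SyncPattern {release_idx, acquire_idx} means that statement release_idx
  // arrives at a barrier which statement acquire_idx waits on; patterns with
  // acquire_idx < release_idx are loop-carried and flip the parity bit.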
std::vector<SyncPattern> CreateBaseSyncPairs(Array<Stmt> seq_stmt,
const std::vector<bool>& is_producer) {
const int n = seq_stmt.size();
std::vector<std::set<const BufferNode*>> reads, writes;
reads.reserve(n);
writes.reserve(n);
for (int i = 0; i < n; i++) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ seq_stmt[i]);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode*> read_set, write_set;
for (auto region : access[0]) read_set.insert(region->buffer.get());
for (auto region : access[1]) write_set.insert(region->buffer.get());
reads.push_back(std::move(read_set));
writes.push_back(std::move(write_set));
}
auto intersect_fn = [](const std::set<const BufferNode*>& lhs,
const std::set<const BufferNode*>& rhs) {
for (auto ptr : lhs)
if (rhs.count(ptr)) return true;
return false;
};
std::vector<SyncPattern> sync_patterns;
// producer_release consumer_acquire,
// inject before the first consumer stmt for each producer
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j});
break;
}
}
}
    // consumer_release producer_acquire
    // only valid when the sequence is inside a loop (loop-carried dependency)
    // inject before the earliest producer stmt for each consumer
bool in_loop = !is_zero(parity_);
if (in_loop) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < i; j++) {
if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j});
break;
}
}
}
}
return sync_patterns;
}
static std::vector<SyncPattern> RemoveUnusedSyncPatterns(
const std::vector<SyncPattern>& sync_patterns, const std::vector<bool>& is_producer) {
/*
Simplify multiple release-acquire pairs into one
------------------
Produce(A)
Produce(B)
Consume(A, B)
------------------
[(0, 2), (1, 2), (2, 0)] -> [(1, 2), (2, 0)]
Or
------------------
Produce(A, B)
Consume(A)
Consume(B)
------------------
[(0, 1), (1, 0), (2, 0)] -> [(0, 1), (2, 0)]
*/
int M = sync_patterns.size();
std::vector<bool> removed(M, false);
for (int i = 0; i < M; i++) {
for (int j = 0; j < M; j++) {
if (is_producer[sync_patterns[i].acquire_idx] ==
is_producer[sync_patterns[j].acquire_idx] &&
sync_patterns[i].acquire_idx >= sync_patterns[j].acquire_idx &&
sync_patterns[i].release_idx < sync_patterns[j].release_idx)
removed[i] = true;
}
}
std::vector<SyncPattern> sync_pattern_cleaned;
sync_pattern_cleaned.reserve(M);
for (int i = 0; i < M; i++)
if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]);
return sync_pattern_cleaned;
}
SyncPatternMap ExtractSyncPattern(Array<Stmt> seq_stmt) {
size_t num_stmts = seq_stmt.size();
std::vector<bool> is_producer;
is_producer.reserve(num_stmts);
for (auto stmt : seq_stmt) {
is_producer.push_back(marker_.GetRole(stmt) == Role::kProducer);
}
auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer);
auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer);
// for (auto pattern : sync_patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// }
SyncPatternMap map;
map.patterns = sync_patterns;
map.acquire.resize(num_stmts, -1);
map.release.resize(num_stmts, -1);
map.release_after.resize(num_stmts, false);
for (size_t i = 0; i < sync_patterns.size(); i++) {
map.acquire[sync_patterns[i].acquire_idx] = i;
map.release[sync_patterns[i].release_idx] = i;
map.release_after[sync_patterns[i].release_idx] = true;
}
int cur_consumer_barrier = -1, cur_producer_barrier = -1;
for (int i = num_stmts - 1; i >= 0; i--) {
if (is_producer[i]) {
if (map.release[i] == -1) {
map.release[i] = cur_producer_barrier;
} else {
cur_producer_barrier = map.release[i];
}
} else {
if (map.release[i] == -1) {
map.release[i] = cur_consumer_barrier;
} else {
cur_consumer_barrier = map.release[i];
}
}
}
return map;
}
const bool is_emitting_producer_;
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_set<int> released_barrier_;
const WarpSpecializedRoleMarker& marker_;
int num_barriers_ = 0;
PrimExpr parity_ = 0;
PrimExpr stage_ = 0;
int num_stages_ = 1;
Var thread_var_;
PipelineInfo pipeline_info_;
friend class WarpSpecializedRewriter;
};
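// Top-level rewriter: for each kernel block with a detected producer it emits
// producer and consumer versions of the body, guards them with a predicate on
// threadIdx.x, extends the thread extent to cover the extra producer threads,
// and prepends the mbarrier initialization.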
class WarpSpecializedRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = WarpSpecializedRewriter();
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer);
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_iv_ = Downcast<IterVar>(op->node);
need_update_thread_extent_ = false;
AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
if (need_update_thread_extent_) {
thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
attr_stmt.CopyOnWrite()->node = thread_iv_;
attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
}
thread_iv_ = {};
return attr_stmt;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
  // If users define a thread binding, we replace its loop variable with threadIdx.x.
  // We require that the thread binding is threadIdx.x and that its extent matches the thread extent.
Stmt VisitStmt_(const ForNode* op) final {
ICHECK(thread_iv_.defined());
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (for_node->kind == ForKind::kThreadBinding) {
ICHECK(for_node->thread_binding.defined());
String thread_tag = for_node->thread_binding.value()->thread_tag;
ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
Var thread_iv = Downcast<Var>(for_node->loop_var);
Stmt new_body = ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
return new_body;
}
return for_node;
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize block_realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
if (!thread_iv_.defined()) {
return block_realize;
}
Block block = block_realize->block;
WarpSpecializedRoleMarker marker(buffer_data_to_buffer_);
marker(block);
if (!marker.HasProducer()) {
// Cannot detect any producer here, directly return.
return block_realize;
}
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
    // Only one warp-group is needed for the bulk-copy-only case.
if (!marker.HasSimtCopy()) producer_thread_extent = 128;
// TODO: estimate the correct reg usage.
auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1}));
auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0}));
producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent);
updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
need_update_thread_extent_ = true;
ICHECK(producer.num_barriers_ == consumer.num_barriers_)
<< producer.num_barriers_ << " " << consumer.num_barriers_;
int num_barriers = consumer.num_barriers_;
Array<PrimExpr> barrier_num_threads;
barrier_num_threads.reserve(num_barriers);
for (int i = 0; i < num_barriers; i++) {
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i) ? producer_thread_extent : consumer_thread_extent;
barrier_num_threads.push_back(arrive_thread_count);
}
Stmt init_barrier =
Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
Stmt body =
IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code);
    // Add an attr here so the ThreadSync pass can handle the partial thread count.
Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
Downcast<IntImm>(consumer_thread_extent)};
body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);
block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
WarpSpecializedRewriter() = default;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Optional<Stmt>> buffer_lca_;
Map<Buffer, Buffer> buffer_remap_;
IterVar thread_iv_;
Optional<PrimExpr> updated_thread_extent_;
bool need_update_thread_extent_ = false;
};
using namespace tir::transform;
tvm::transform::Pass WarpSpecialized() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return WarpSpecializedRewriter::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized").set_body_typed(WarpSpecialized);
} // namespace tl
} // namespace tvm
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
chunk = 32
shared_scope = "shared.dyn"
cache_write_shared = False
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 64
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
    # Matrix Core (MFMA) wrapper to auto-generate the intrinsic code
mfma_emitter = MatrixCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=0):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local)
# Perform STMatrix
if cache_write_shared:
mfma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
else:
mfma_emitter.stmatrix(
C_local,
C,
thread_bindings=thread_bindings,
pid_m=by,
pid_n=bx,
)
return main
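# Build and lower the kernel, run it through the TileLang profiler, and check
# the result against a torch.matmul reference (B is laid out as (N, K), hence B.T).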
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
    # src_code is the generated kernel source (HIP on ROCm)
assert src_code is not None
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
k_pack=1,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
vec_size = 4 * k_pack
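    # The coalesced copy width scales with k_pack, which is also forwarded to T.gemm below.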
@T.prim_func
def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared, coalesced_width=vec_size)
else:
T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size)
else:
T.copy(B[k * block_K, bx * block_N], B_shared, coalesced_width=vec_size)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B, k_pack=k_pack)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
k_pack=1,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
k_pack=k_pack,
)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f32f32_nt():
run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
torch.manual_seed(0)
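# Use the hardware swizzle layout only when one shared-memory row spans exactly
# 512 bits; otherwise fall back to the identity layout.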
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
def tl_matmul_macro(
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
M = tvm.te.var("m")
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def tl_matmul_block(
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
M = tvm.te.var("m")
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def assert_tl_matmul_block_correctness(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = tl_matmul_block(
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = TL.lower(program)
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
# Get Reference Result
ref_c = ref_program(A, B)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def tl_matmul_block_all_dynamic(
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
M = tvm.te.var("m")
N = tvm.te.var("n")
K = tvm.te.var("k")
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
@T.prim_func
def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def assert_tl_matmul_block_all_dynamic_correctness(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = tl_matmul_block_all_dynamic(
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = TL.lower(program)
if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
if trans_B:
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
else:
B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
# Get Reference Result
ref_c = ref_program(A, B)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul_macro():
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16")
def test_assert_tl_matmul_block():
assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16",
64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic():
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
"float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 115, 103, False, False, "float16", "float16",
"float16", 64, 64, 32)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TODO: implement this test for tilelang/language/kernel.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
torch.manual_seed(0)
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
from bitblas.quantization import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
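    # Each thread stages local_size_compressed packed int8 values into registers,
    # unpacks them into local_size in_dtype values, and writes the result to
    # B_dequantize_shared before the GEMM.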
    import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
tx = T.thread_binding(0, threads, thread="threadIdx.x")
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
vj = index % block_K
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
out = mod.run_once()
assert out is not None
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program)
@tvm.testing.requires_package("bitblas")
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_rows = 4
warp_cols = 4
warp_row_tiles = micro_size_x * warp_rows
warp_col_tiles = micro_size_y * warp_cols
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
reduce_k = 1
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
chunk = block_K // reduce_k
is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
pad_factor = 8
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k // num_elems_per_byte,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
rk = T.thread_binding(0, reduce_k, "threadIdx.y")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_bindings
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
if reduce_k > 1:
for n in T.serial(warp_rows * warp_cols * local_size):
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_local[n],
True,
reduced_accum_res[0],
rk,
dtype="handle",
))
if rk == 0:
C_local[n] = reduced_accum_res[0]
if rk == 0:
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
return main
def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=K,
transform_kind=transform_b,
transpose_matrix=True,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype=in_dtype,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, QLB, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
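    # Unpack the packed int4 weights into a float16 reference matrix: each int8 element of qB
    # holds two 4-bit values, the low nibble for even j and the high nibble for odd j.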
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("Ref C: ", ref_c)
print("C: ", C)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def test_gemm_f16f16f32_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float32",
128,
128,
32,
)
def test_gemm_bf16bf16f32_nn():
run_gemm(
512,
1024,
768,
False,
False,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
)
def test_gemm_f32f32f32_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float32",
"float32",
"float32",
64,
128,
32,
)
def test_gemm_i8i8i32_nn():
run_gemm(
512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64
)
def test_gemm_f16f16f16_tn():
run_gemm(
512,
1024,
768,
True,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def test_gemm_f16f16f16_nt():
run_gemm(
512,
1024,
768,
False,
True,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def test_gemm_i8i8i32_nt():
run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64)
def test_gemm_i8i8i32_tn():
run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64)
def test_gemm_f64f64f64_nt():
run_gemm(
512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16
)
def test_gemm_f32f32f32_nt():
run_gemm(
512,
1024,
768,
False,
True,
"float32",
"float32",
"float32",
64,
128,
32,
)
def test_gemm_f32f32f32_tn():
run_gemm(
512,
1024,
768,
True,
False,
"float32",
"float32",
"float32",
64,
128,
32,
)
def test_pad_aligned_f16f16f16_nn():
run_gemm(
512 - 8,
1024 - 32,
768 - 24,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def test_pad_f16f16f16_nn():
run_gemm(
512 - 9,
1024 - 7,
768 - 5,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def test_pad_f16f16f32_nn():
run_gemm(
512 + 19,
1024 + 17,
768 + 15,
False,
False,
"float16",
"float16",
"float32",
128,
64,
32,
)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
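    # Swizzling is only applied when one row of the shared tile spans exactly 512 bits (64 bytes),
    # e.g. 32 float16 elements (32 * 16 bits) or 64 int8 elements (64 * 8 bits).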
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul_simt(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
# This is a debug config
block_size_x = 8
block_size_y = 8
thread_row_tiles = 16
thread_col_tiles = 16
chunk = 16
shared_scope = "shared"
block_M = block_size_x * thread_row_tiles
block_N = block_size_y * thread_col_tiles
block_K = chunk
# Pipeline Stage
A_shape = (M, K)
B_shape = (N, K)
C_shape = (M, N)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
threads = thread_row_tiles * thread_col_tiles
local_size_a = block_M // thread_row_tiles
local_size_b = block_N // thread_col_tiles
local_size_c = (block_M // thread_row_tiles) * (block_N // thread_col_tiles)
micro_size_k = 128 // DataType(in_dtype).bits
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
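    # When both operands are int8 and the accumulator is int32, the inner product can use the
    # dp4a intrinsic (a 4-way int8 multiply-accumulate into int32); otherwise the kernel below
    # falls back to a scalar loop over dp4a_size elements.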
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
A_local = T.alloc_local((local_size_a, micro_size_k), in_dtype)
B_local = T.alloc_local((local_size_b, micro_size_k), in_dtype)
C_local = T.alloc_local((local_size_c,), accum_dtype)
thread_binding = T.thread_binding(threads, "threadIdx.x")
warp_m = thread_binding % thread_row_tiles
warp_n = thread_binding // thread_row_tiles
T.clear(C_local)
for ko in T.serial(K // block_K):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial((block_K // micro_size_k)):
for i in T.serial(local_size_a):
for mk in T.vectorized(micro_size_k):
A_local[i, mk] = A_shared[warp_m * local_size_a + i,
ki * micro_size_k + mk]
for i in T.serial(local_size_b):
for mk in T.vectorized(micro_size_k):
B_local[i, mk] = B_shared[warp_n * local_size_b + i,
ki * micro_size_k + mk]
for i, j in T.grid(local_size_a, local_size_b):
for mk in T.serial(micro_size_k // dp4a_size):
if use_dp4a:
T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size],
C_local[i * local_size_b + j])
else:
for dp4a_idx in T.serial(dp4a_size):
C_local[i * local_size_b +
j] += A_local[i, mk * dp4a_size +
dp4a_idx] * B_local[j, mk * dp4a_size +
dp4a_idx]
for i, j in T.grid(local_size_a, local_size_b):
C[by * block_M + warp_m * local_size_a + i,
bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
print(src_code)
# src_code is the generated cuda source
assert src_code is not None
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,
INT4TensorCoreIntrinEmitterWithLadderTransform,
)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
K = K // 2
micro_size_x = micro_size_y = micro_size_k = 16
if accum_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K) # int8 storage represents int4*2
B_shape = (N, K) # int8 storage represents int4*2
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = INT4TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
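    # Pack two 4-bit values per int8 element: the even-indexed column goes into the low nibble and
    # the odd-indexed column into the high nibble, matching the int4*2 storage layout noted above.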
compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(compressed_A, compressed_B, C)
print(C)
latency = mod.do_bench(mod.func, warmup=25, profiler="tvm")
print(latency)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@simplify_prim_func
def tl_matmul_weight_only_transform(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
K = K // 2
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
transform_b = 3
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k)
A_shared_shape = (
block_M,
block_K,
)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
transform_kind_b=transform_b,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k,
micro_size_y, micro_size_k):
B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j,
ko * (block_K // micro_size_k) + k, jj, kk]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
transform_b = 3
A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4)
    import bitblas
    ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
        M=N,
        N=(K // 2),
        datatype="int8",
        storage_dtype="int8",
        transform_kind=transform_b,
        transpose_matrix=True,
    )
    ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
LB = ladder_permutate(compressed_B.cpu()).cuda()
mod(compressed_A, LB, C)
latency = mod.do_bench(mod.func, warmup=25)
print(f"Latency: {latency}")
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_weight_only_transform():
assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
from tilelang import primitives as P
def matmul_ssr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
P.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_ssr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_ssr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
print(program)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
def matmul_rsr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_local_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A_shared, A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A_shared, A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
P.gemm(A_local, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_rsr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rsr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_nt_rsr():
run_matmul_rsr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
16,
16,
16,
0,
num_threads=32,
)
def matmul_rrr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_local_shape = A_shared_shape
B_local_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_local = T.alloc_fragment(B_local_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A_shared, A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A_shared, A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B_shared, B_local)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_local)
P.gemm(A_local, B_local, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_rrr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rrr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = tl.lower(program)
mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_nt_rrr():
run_matmul_rrr(
1024,
1024,
1024,
False,
True,
"float16",
"float16",
"float16",
128,
128,
32,
2,
)
if __name__ == "__main__":
# tilelang.testing.main()
# test_gemm_f16f16f16_nt_ssr()
test_gemm_f16f16f16_nt_rsr()
# test_gemm_f16f16f16_nt_rrr()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
def modify(
with_B: bool = False,
with_bias: bool = False,
):
@T.prim_func
def main(
A: T.Buffer((64, 64)),
B: T.Buffer((64, 64)),
C: T.Buffer((64, 64)),
D: T.Buffer((64, 64)),
bias: T.Buffer((64, 64)),
):
if with_B:
if with_bias:
T.gemm(A, bias, D)
T.gemm(A, B, D)
else:
with T.block():
A_shared = T.alloc_shared((64, 64), dtype="float32")
C_shared = T.alloc_shared((64, 64), dtype="float32")
D_shared = T.alloc_shared((64, 64), dtype="float32")
T.copy(A, A_shared)
T.copy(C, C_shared)
T.gemm(A_shared, C_shared, D_shared)
T.copy(D_shared, D)
return main
def test_modify(with_B=False, with_bias=False):
tester = modify(with_B=with_B, with_bias=with_bias)
mod = tvm.IRModule({tester.attrs["global_symbol"]: tester})
mod2 = tl.transform.Simplify()(mod)
assert mod != mod2
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
a: T.handle,
b: T.handle,
c: T.handle,
):
A = T.match_buffer(a, (M, K), dtype=dtype)
B = T.match_buffer(b, (K, N), dtype=dtype)
C = T.match_buffer(c, (M, N), dtype=accum_dtype)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def test_matmul():
func = matmul(1024, 1024, 1024, 128, 128, 32)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
mod = tl.transform.Simplify()(mod)
rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False)
# TODO Profiler only support TensorType, not dynamic variable
profiler = tl.Profiler(rt_mod, params, result_idx=[2])
import torch
a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
c = profiler(a, b)
ref_c = a @ b
ref_c = ref_c.float()
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Source
# print(rt_mod.imported_modules[0].get_source())
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import os
import ctypes
import warnings
import functools
import logging
from tqdm import tqdm
class TqdmLoggingHandler(logging.Handler):
"""Custom logging handler that directs log output to tqdm progress bar to avoid interference."""
def __init__(self, level=logging.NOTSET):
"""Initialize the handler with an optional log level."""
super().__init__(level)
def emit(self, record):
"""Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted."""
try:
msg = self.format(record)
tqdm.write(msg)
except Exception:
self.handleError(record)
def set_log_level(level):
"""Set the logging level for the module's logger.
Args:
level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO).
OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
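
    Example (illustrative):
        set_log_level("DEBUG")        # by name
        set_log_level(logging.DEBUG)  # or by logging constant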
"""
if isinstance(level, str):
level = getattr(logging, level.upper(), logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(level)
def _init_logger():
"""Initialize the logger specific for this module with custom settings and a Tqdm-based handler."""
logger = logging.getLogger(__name__)
handler = TqdmLoggingHandler()
formatter = logging.Formatter(
fmt="%(asctime)s [TileLang:%(levelname)s]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
set_log_level("WARNING")
_init_logger()
def deprecated(reason):
"""
This is a decorator which can be used to mark functions as deprecated.
It will result in a warning being emitted when the function is used.
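
    Example (illustrative):

        @deprecated("reason shown in the DeprecationWarning")
        def old_api(*args, **kwargs):
            ...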
"""
def decorator(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
warnings.warn(
f"Call to deprecated function {func.__name__} ({reason}).",
category=DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)
return new_func
return decorator
logger = logging.getLogger(__name__)
# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path")
", which may lead to compilation bugs when utilize tilelang backend."
TVM_LIBRARY_NOT_FOUND_MESSAGE = ("TVM is not installed or found in the expected path")
SKIP_LOADING_TILELANG_SO = os.environ.get("SKIP_LOADING_TILELANG_SO", "0")
# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)
if TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python")
else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, install_tvm_path + "/python")
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
sys.path.insert(0, develop_tvm_path + "/python")
if os.environ.get("TVM_LIBRARY_PATH") is None:
if os.path.exists(develop_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = develop_tvm_library_path
elif os.path.exists(install_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = install_tvm_library_path
else:
logger.warning(TVM_LIBRARY_NOT_FOUND_MESSAGE)
if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)
if os.environ.get("TL_TEMPLATE_PATH", None) is None:
install_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
develop_tl_template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
if os.path.exists(install_tl_template_path):
os.environ["TL_TEMPLATE_PATH"] = install_tl_template_path
elif (os.path.exists(develop_tl_template_path) and develop_tl_template_path not in sys.path):
os.environ["TL_TEMPLATE_PATH"] = develop_tl_template_path
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
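# The path setup in this module can be overridden through the environment variables read above;
# an illustrative invocation (the paths are placeholders, not defaults):
#   TVM_IMPORT_PYTHON_PATH=/path/to/tvm TL_CUTLASS_PATH=/path/to/cutlass/include \
#   TL_TEMPLATE_PATH=/path/to/tilelang/src TVM_LIBRARY_PATH=/path/to/tvm/build python your_script.py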
import tvm
import tvm._ffi.base
from . import libinfo
def _load_tile_lang_lib():
"""Load Tile Lang lib"""
if sys.platform.startswith("win32") and sys.version_info >= (3, 8):
for path in libinfo.get_dll_directories():
os.add_dll_directory(path)
# pylint: disable=protected-access
lib_name = "tilelang" if tvm._ffi.base._RUNTIME_ONLY else "tilelang_module"
# pylint: enable=protected-access
lib_path = libinfo.find_lib_path(lib_name, optional=False)
return ctypes.CDLL(lib_path[0]), lib_path[0]
# only load once here
if SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()
from .utils import (
Profiler, # noqa: F401
TensorSupplyType, # noqa: F401
)
from .layout import (
Layout, # noqa: F401
Fragment, # noqa: F401
)
from . import (
transform, # noqa: F401
autotuner, # noqa: F401
language, # noqa: F401
engine, # noqa: F401
)
from .engine import lower # noqa: F401
from .version import __version__ # noqa: F401