Unverified commit bbbf4207, authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
@@ -37,7 +37,7 @@
 namespace tvm {
 namespace tl {
 
+using namespace ffi;
 namespace tir = tvm::tir;
 
 class HostDeviceSplitter : public tir::StmtMutator {
@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() {
       {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
-});
+}
 
 } // namespace transform
 } // namespace tl
......
@@ -29,6 +29,7 @@
 #include <string>
 #include <utility>
 
+#include "../op/builtin.h"
 #include "tir/transforms/ir_utils.h"
 
 namespace tvm {
@@ -38,10 +39,11 @@ using namespace tir;
 
 void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
-    ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
+    ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
+                          << scope.to_string();
     AccessEntry e;
     e.threads = env_threads();
     e.thread_range = this->ComputeThreadRange(e.threads);
@@ -65,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
   curr_stmt_.stmt = op;
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
     AccessEntry e;
@@ -252,7 +254,11 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
   this->VisitExpr(op->condition);
   PrimExpr real_condition = ExtractRealCondition(op->condition);
-  curr_stmt_.access.clear();
+  // Preserve accesses collected from the condition expression so they
+  // participate in dependency analysis. Otherwise, a write to shared memory
+  // immediately followed by an if-condition reading that memory would not
+  // trigger a sync before the if-statement.
+  std::vector<AccessEntry> cond_access = std::move(curr_stmt_.access);
   allow_append_ = false;
   scope_.push_back(std::vector<StmtEntry>());
@@ -265,6 +271,11 @@
   s.stmt = op;
   s.access = Summarize(std::move(scope_.back()), nullptr);
   scope_.pop_back();
+  // Merge the condition's access summary into the if-statement's access list
+  // so the planner can insert a sync before the if when necessary.
+  if (!cond_access.empty()) {
+    s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end());
+  }
   if (op->else_case) {
     scope_.push_back(std::vector<StmtEntry>());
     {
@@ -301,14 +312,32 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
 }
 
 void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
+  // Mark async TMA load context so that tvm_access_ptr within the call
+  // can be tagged accordingly.
+  auto is_tma_load = [&]() {
+    if (auto opt = op->op.as<Op>()) {
+      const Op &call_op = opt.value();
+      return call_op.same_as(tl::tma_load()) ||
+             call_op.same_as(tl::tma_load_im2col());
+    }
+    return false;
+  }();
+  if (is_tma_load) {
+    tma_depth_++;
+    for (const auto &a : op->args) {
+      this->VisitExpr(a);
+    }
+    tma_depth_--;
+    return;
+  }
   if (op->op.same_as(builtin::address_of())) {
     ICHECK_EQ(op->args.size(), 1U);
     if (auto load = op->args[0].as<BufferLoadNode>()) {
       Buffer buffer = load->buffer;
       DataType dtype = buffer->dtype;
       const VarNode *buffer_var = buffer->data.as<VarNode>();
-      buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
-      StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+      buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
+      StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
       Array<Range> buffer_ranges;
       // from indices to buffer indices
       ICHECK(buffer->shape.size() == load->indices.size());
@@ -346,17 +375,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     PrimExpr offset = op->args[2];
     PrimExpr extent = op->args[3];
     const IntImmNode *flag = op->args[4].as<IntImmNode>();
-    StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+    StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
     // The buffer scope.
     if (Enabled(buffer_var, scope)) {
       ICHECK(allow_append_);
       Array<Range> buffer_ranges;
-      if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
+      if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
           buffer_data_to_buffer_.end()) {
         // cannot find buffer map, use the default buffer
         buffer_ranges = {Range::FromMinExtent(offset, extent)};
       } else {
-        Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
+        Buffer buffer =
+            buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
         auto buffer_shape = buffer->shape;
         // convert 1d offset to multi-dimensional index
         auto linear_to_indices = [this](PrimExpr offset,
@@ -387,7 +417,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.threads = env_threads();
       e.thread_range = this->ComputeThreadRange(e.threads);
       e.dtype = dtype;
-      e.buffer = GetRef<Var>(buffer_var);
+      e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
       e.buffer_ranges = buffer_ranges;
       e.is_pointer_access = true;
       e.touched = {
@@ -395,10 +425,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.scope = scope;
       if (flag->value & 1) {
         e.type = kRead;
+        e.is_async_copy = (tma_depth_ > 0);
         curr_stmt_.access.emplace_back(e);
       }
       if (flag->value & 2) {
         e.type = kWrite;
+        e.is_async_copy = (tma_depth_ > 0);
         curr_stmt_.access.emplace_back(e);
       }
     }
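
The IfThenElse change above is the behavioral core of this commit's sync fix: reads performed by an if-condition must stay visible to the conflict planner. A minimal Python sketch of the idea (a toy model, not the pass itself; needs_sync_before and the tuple encoding of accesses are illustrative):

# Toy model of the planner's conflict check: a barrier is needed before a
# statement whose reads conflict with writes that are still pending.
def needs_sync_before(stmt_accesses, pending_writes):
    return any(kind == "read" and buf in pending_writes
               for kind, buf in stmt_accesses)

pending_writes = {"A_shared"}              # a preceding write to shared memory
cond_accesses = [("read", "A_shared")]     # the if-condition reads it

# Old behaviour: curr_stmt_.access.clear() discarded the condition's accesses,
# so the planner saw nothing to conflict with and skipped the barrier.
assert not needs_sync_before([], pending_writes)
# New behaviour: the condition's accesses are merged into the if-statement's
# access list, so the conflict is visible and a sync is inserted.
assert needs_sync_before(cond_accesses, pending_writes)
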
......
@@ -39,6 +39,7 @@ namespace tvm {
 namespace tl {
 
 using namespace tir;
+using namespace ffi;
 using arith::IRVisitorWithAnalyzer;
 using runtime::StorageRank;
 using runtime::StorageScope;
@@ -83,6 +84,10 @@ public:
   bool double_buffer_write = false;
   /*! \brief Whether the access is pointer access */
   bool is_pointer_access = false;
+  /*! \brief Whether this access originates from an async copy context
+   *  (e.g., inside a TMA load) and therefore multiple writes
+   *  among themselves should not force barriers between them. */
+  bool is_async_copy = false;
 };
 
 /*! \brief Access pattern about a single statement */
@@ -159,6 +164,8 @@ private:
   bool allow_append_{false};
   // Whether we are in device environment
   bool in_device_env_{false};
+  // Nesting depth of tma_load/tma_load_im2col calls
+  int tma_depth_{0};
   // Whether we are inside condition.
   int condition_counter_{0};
   // The current double buffer write scope.
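
To see how the new is_async_copy flag and tma_depth_ counter are meant to be consumed (the consumer is FindConflict in the ThreadSync pass further down), here is a small illustrative sketch; the tuple encoding is an assumption, not the real AccessEntry:

# Each access is modeled as (type, is_async_copy) over one shared buffer.
def find_conflict(prev, curr):
    (ptype, pasync), (ctype, casync) = prev, curr
    # Two async-copy writes (e.g., TMA loads into shared memory) may complete
    # in any order; no barrier is required between them.
    if ptype == "write" and ctype == "write" and pasync and casync:
        return False
    return True  # conservative default for overlapping accesses

assert not find_conflict(("write", True), ("write", True))  # TMA + TMA: no sync
assert find_conflict(("write", True), ("read", False))      # consumer still syncs
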
......
@@ -544,7 +544,7 @@ public:
       }
       return it->second->alloc_var;
     } else {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
   }
 
   PrimExpr VisitExpr_(const CallNode *op) final {
@@ -679,7 +679,7 @@ private:
     return !scope.tag.empty() && scope.tag != ".dyn" &&
            scope.tag != ".barrier" && scope.tag != ".workspace" &&
            scope.tag != ".vtcm" && scope.tag != ".var" &&
-           scope.tag != ".descriptor";
+           scope.tag.find(".descriptor") != 0;
   }
 
   // Allocate entry of node.
@@ -865,7 +865,7 @@ private:
     ICHECK_NE(e->const_nbits, 0U);
     MemoryInfo info;
     if (e->scope.tag != ".barrier" && e->scope.tag != ".var" &&
-        e->scope.tag != ".descriptor") {
+        e->scope.tag.find(".descriptor") != 0) {
       info = GetMemoryInfo(e->scope.to_string());
     }
     uint64_t total_bits = e->const_nbits;
@@ -978,8 +978,8 @@ private:
     ICHECK(alloc_info.count(var));
     const AllocEntry &entry = alloc_info.at(var);
    const AllocateNode *alloc = entry.alloc;
-    auto storage_scope =
-        StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
+    auto storage_scope = StorageScope::Create(
+        GetPtrStorageScope(tvm::ffi::GetRef<Var>(var)));
     StorageEntry *dst_entry = nullptr;
     // inplace detection
     if (detect_inplace) {
@@ -1425,9 +1425,30 @@ public:
   void
   OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent,
                      BufferVarInfo::DeclarationLocation declaration_location) {
-    ICHECK(info_map_.find(buffer.get()) == info_map_.end())
-        << "Array declaration of " << buffer->name_hint
-        << " occurred multiple times.";
+    auto it = info_map_.find(buffer.get());
+    if (it != info_map_.end()) {
+      // The same buffer var may appear in more than one Allocate due to
+      // upstream transforms (e.g., storage planning/merging). Treat repeated
+      // declarations as benign and merge metadata instead of erroring.
+      BufferVarInfo &existing = it->second;
+      // Prefer a concrete element dtype if the previous one was a handle.
+      if (existing.element_dtype.is_handle() && !element_dtype.is_handle()) {
+        existing.element_dtype =
+            element_dtype == DataType::Bool()
+                ? DataType::Int(8).with_lanes(element_dtype.lanes())
+                : element_dtype;
+      }
+      // If extent was previously unknown (0) and a concrete extent is
+      // provided now, record it.
+      if (!existing.extent.defined() || is_zero(existing.extent)) {
+        existing.extent = extent;
+      }
+      // Merge declaration locations (bitwise OR of flags).
+      existing.declaration_location =
+          static_cast<BufferVarInfo::DeclarationLocation>(
+              existing.declaration_location | declaration_location);
+      return;
+    }
 
     if (element_dtype == DataType::Bool()) {
       element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
@@ -1732,7 +1753,7 @@ public:
     Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
     if (var.same_as(op->var) && value.same_as(op->value) &&
         body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     }
     return LetStmt(var, value, body);
   }
@@ -1985,10 +2006,10 @@ Pass StorageRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
-});
+}
 
 Pass PointerValueTypeRewrite() {
   auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
@@ -1997,11 +2018,11 @@ Pass PointerValueTypeRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
                         PointerValueTypeRewrite);
-});
+}
 
 } // namespace transform
 } // namespace tl
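
The OnArrayDeclaration change above merges metadata instead of failing on a repeated declaration. A rough Python analogue of the merge rules (the names and the IntFlag encoding are illustrative, not the real BufferVarInfo):

from enum import IntFlag

class DeclLoc(IntFlag):
    PARAM = 1
    ALLOCATE = 2

info_map = {}

def on_array_declaration(var, dtype, extent, loc):
    entry = info_map.get(var)
    if entry is None:
        info_map[var] = {"dtype": dtype, "extent": extent, "loc": loc}
        return
    # Repeated declaration: merge rather than raise.
    if entry["dtype"] == "handle" and dtype != "handle":
        entry["dtype"] = dtype        # prefer the concrete element dtype
    if not entry["extent"]:
        entry["extent"] = extent      # fill in a previously unknown extent
    entry["loc"] |= loc               # bitwise OR of declaration locations

on_array_declaration("buf", "handle", 0, DeclLoc.PARAM)
on_array_declaration("buf", "float16", 256, DeclLoc.ALLOCATE)
assert info_map["buf"]["loc"] == DeclLoc.PARAM | DeclLoc.ALLOCATE
assert info_map["buf"]["dtype"] == "float16" and info_map["buf"]["extent"] == 256
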
......
@@ -86,6 +86,7 @@ protected:
       // check if sync before statement is needed.
       bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
       // Apply the syncs added already.
+
       if (sync_before_stmt) {
         reads.clear();
         writes.clear();
@@ -98,7 +99,8 @@ protected:
             break;
           }
         } else if (acc.type == kWrite) {
-          if (FindConflict(reads, acc, false)) {
+          if (FindConflict(reads, acc, false) ||
+              FindConflict(writes, acc, false)) {
             sync_before_stmt = true;
             break;
           }
@@ -123,27 +125,51 @@ protected:
           writes.clear();
         }
       }
       if (sync_before_stmt) {
         insert_syncs(s.stmt);
       }
     }
     if (loop != nullptr) {
+      // Check if the loop body contains any reads in the same sync scope.
+      // If there are reads, we conservatively keep the sync within the loop
+      // body to preserve per-iteration ordering when needed. If there are no
+      // reads (e.g., only writes to shared.dyn), we can safely hoist the sync
+      // to before the loop to avoid redundant barriers.
+      bool has_read_in_scope = false;
+      for (const StmtEntry &s : seq) {
+        for (const AccessEntry &acc : s.access) {
+          if (acc.type == kRead && acc.scope == sync_scope_) {
+            has_read_in_scope = true;
+            break;
+          }
+        }
+        if (has_read_in_scope)
+          break;
+      }
+
+      // If there is a loop-carried dependency, insert a single sync
+      // before the loop rather than hoisting a sync into the loop body.
+      // This reduces redundant per-iteration synchronizations for cases
+      // where each iteration touches disjoint regions (e.g., stmatrix
+      // writes to shared.dyn) and only a global ordering before/after the
+      // loop is required.
       for (size_t i = 0; i < seq.size(); ++i) {
         const StmtEntry &s = seq[i];
         if (syncs_inserted_.count(s.stmt) != 0)
           break;
         if (reads.empty() && writes.empty())
           break;
-        bool sync_before_stmt = false;
+        bool need_loop_sync = false;
         for (const AccessEntry &acc : s.access) {
           if (acc.type == kRead) {
             if (FindConflict(writes, acc, true)) {
-              sync_before_stmt = true;
+              need_loop_sync = true;
               break;
             }
           } else if (acc.type == kWrite) {
-            if (FindConflict(reads, acc, true)) {
-              sync_before_stmt = true;
+            if (FindConflict(reads, acc, true) ||
+                FindConflict(writes, acc, true)) {
+              need_loop_sync = true;
               break;
             }
           } else if (acc.type == kSync) {
@@ -151,8 +177,17 @@ protected:
             writes.clear();
           }
         }
-        if (sync_before_stmt) {
-          insert_syncs(s.stmt);
+        if (need_loop_sync) {
+          if (!has_read_in_scope) {
+            // Mark the loop itself to receive a sync before it, instead of
+            // inserting inside the loop body. This ensures a single sync is
+            // emitted outside the loop and avoids per-iteration overhead.
+            insert_syncs(loop);
+          } else {
+            // Fall back to inserting before the first conflicting statement
+            // inside the loop to maintain correctness when reads are present.
+            insert_syncs(s.stmt);
+          }
           break;
         }
       }
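
The hoisting rule the two hunks above implement can be stated compactly: if the loop touches the sync scope with writes only, one barrier before the loop suffices; any read in the scope forces the barrier back inside the loop. A toy decision function (illustrative names, not the pass API):

def plan_loop_sync(loop_accesses, sync_scope="shared.dyn"):
    has_read = any(kind == "read" and scope == sync_scope
                   for kind, scope in loop_accesses)
    # Writes only: hoist a single barrier before the loop.
    # Reads present: keep the barrier inside to order each iteration.
    return "inside_loop" if has_read else "before_loop"

assert plan_loop_sync([("write", "shared.dyn")]) == "before_loop"
assert plan_loop_sync([("write", "shared.dyn"),
                       ("read", "shared.dyn")]) == "inside_loop"
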
@@ -217,6 +252,14 @@ private:
   bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                     bool loop_carry) {
+    // Special case: ignore conflicts between async-copy writes (e.g., TMA
+    // loads into shared memory). Multiple async writes do not require
+    // interspersed barriers among themselves. We still respect conflicts with
+    // reads to ensure visibility before consumption.
+    if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
+        curr.is_async_copy) {
+      return false;
+    }
     // Access to different buffers does not conflict.
     if (!prev.buffer.same_as(curr.buffer)) {
       return false;
@@ -241,10 +284,15 @@ private:
       return true;
     }
     if (prev.is_pointer_access || curr.is_pointer_access) {
-      // If either access is a pointer access, conservatively assume a
-      // conflict. For example, address_of(A[0, 0]) may refer to an unknown
-      // memory region, so we cannot safely determine if it overlaps with
-      // previous accesses.
+      // For accesses created via tvm_access_ptr we may still be able to prove
+      // disjointness using their byte ranges. If both sides expose a touched
+      // interval and we can show they don't overlap, skip the conflict.
+      if (prev.is_pointer_access && curr.is_pointer_access &&
+          PointerAccessIsDisjoint(prev, curr)) {
+        return false;
+      }
+      // Otherwise fall back to the conservative answer: treat them as
+      // overlapping.
       return true;
     }
@@ -327,7 +375,7 @@ private:
         }
       }
-      if (!(has_same_index)) {
+      if (!has_same_index) {
         break;
       }
     }
@@ -350,6 +398,26 @@ private:
     return range_is_overlap;
   }
 
+  bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
+    if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
+      return false;
+    }
+    PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
+    PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
+    PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
+    PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
+    if (analyzer_.CanProve(lhs_max < rhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    if (analyzer_.CanProve(rhs_max < lhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    return false;
+  }
+
   void VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == tvm::tir::attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -782,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
   return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
-});
+}
 
 } // namespace transform
 } // namespace tl
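
PointerAccessIsDisjoint reduces to an interval test: two single-interval pointer accesses cannot conflict if one provably ends before the other begins. With concrete integers the proof is a plain comparison; the pass uses arith::Analyzer::CanProve with symbolic bounds instead. A minimal sketch:

def pointer_access_is_disjoint(lhs, rhs):
    (lmin, lmax), (rmin, rmax) = lhs, rhs     # touched element intervals
    return lmax < rmin or rmax < lmin

# e.g. tvm_access_ptr covering [0, 128) vs. [128, 256): provably disjoint,
# so FindConflict can return False instead of conservatively syncing.
assert pointer_access_is_disjoint((0, 127), (128, 255))
assert not pointer_access_is_disjoint((0, 127), (64, 191))   # overlap: conflict
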
......
@@ -33,6 +33,7 @@
 #include <tvm/tir/transform.h>
 
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
@@ -43,6 +44,7 @@ namespace tvm {
 namespace tl {
 
 using namespace tir;
+using namespace ffi;
 
 /*!
  * \brief Perform data type legalization on the given BufferLoadNode pointer.
@@ -208,6 +210,14 @@ public:
   using ExprFunctor::VisitExpr;
   using StmtMutator::operator();
 
+  // Convenience entry to vectorize a loop body without exposing
+  // the mutator invocation pattern at call sites.
+  static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
+    TLVectorizer vec{var, var_lanes};
+    auto vec_stmt = vec(std::move(body));
+    return vec_stmt;
+  }
+
   TLVectorizer(const Var &var, const PrimExpr &var_lanes)
       : var_(var), var_lanes_(var_lanes) {
     ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
@@ -217,8 +227,9 @@ public:
     ICHECK(!need_scalarize_);
     Stmt ret = StmtMutator::VisitStmt(stmt);
     if (need_scalarize_) {
+      auto scalarized_stmt = Scalarize(stmt);
       need_scalarize_ = false;
-      return Scalarize(stmt);
+      return scalarized_stmt;
     } else {
       return ret;
     }
@@ -242,7 +253,7 @@ public:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
       bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
@@ -296,7 +307,7 @@ public:
   PrimExpr VisitExpr_(const NotNode *op) final {
     PrimExpr a = this->VisitExpr(op->a);
     if (a.same_as(op->a)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return !(a);
     }
@@ -337,10 +348,10 @@ public:
     PrimExpr value = this->VisitExpr(op->value);
     if (value.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Broadcast(op->value, op->lanes);
     }
@@ -352,7 +363,7 @@ public:
     PrimExpr f = this->VisitExpr(op->false_value);
     if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
         f.same_as(op->false_value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
@@ -370,7 +381,7 @@ public:
   PrimExpr VisitExpr_(const CastNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       if (value.dtype().is_scalable_vector()) {
         return Cast(op->dtype.with_scalable_vscale_factor(
@@ -383,26 +394,26 @@ public:
   }
 
   PrimExpr VisitExpr_(const FloatImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   PrimExpr VisitExpr_(const IntImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   PrimExpr VisitExpr_(const StringImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 
   // Variable
   PrimExpr VisitExpr_(const VarNode *op) final {
-    Var var = GetRef<Var>(op);
+    Var var = tvm::ffi::GetRef<Var>(op);
     if (var.same_as(var_)) {
       return ramp_;
     }
-    auto it = let_binding_.find(var);
-    if (it != let_binding_.end()) {
+    auto it = let_var_map_.find(var);
+    if (it != let_var_map_.end()) {
       return it->second;
     } else {
       return std::move(var);
@@ -413,13 +424,13 @@ public:
     PrimExpr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     PrimExpr t = this->VisitExpr(op->args[1]);
     PrimExpr f = this->VisitExpr(op->args[2]);
     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
         f.same_as(op->args[2])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
       int f_lanes = f.dtype().get_lanes_or_vscale_factor();
@@ -441,7 +452,7 @@ public:
     ICHECK(op->op.same_as(builtin::reinterpret()));
     PrimExpr value = this->VisitExpr(op->args[0]);
     if (value.same_as(op->args[0])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {
@@ -478,7 +489,6 @@ public:
     bool vectorizable = optional_op &&
                         op_vectorizable_.get(optional_op.value(), false) &&
                         !op->dtype.is_scalable_vector();
-
     if (!vectorizable) {
       // Cannot vectorize this op
       Array<PrimExpr> new_args;
@@ -486,12 +496,12 @@ public:
         auto new_arg = this->VisitExpr(arg);
         if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
           need_scalarize_ = true;
-          return GetRef<PrimExpr>(op);
+          return tvm::ffi::GetRef<PrimExpr>(op);
         }
         new_args.push_back(new_arg);
       }
       if (op->args.same_as(new_args)) {
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       } else {
         return Call(op->dtype, op->op, new_args);
       }
@@ -500,7 +510,7 @@ public:
       Array<PrimExpr> new_args = MutateArray(op->args, &lane);
       // normal code path.
       if (op->args.same_as(new_args)) {
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       } else {
         return Call(op->dtype.with_lanes(lane), op->op, new_args);
       }
@@ -508,7 +518,7 @@ public:
   }
   // BufferLoad
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load = GetRef<BufferLoad>(op);
+    auto load = tvm::ffi::GetRef<BufferLoad>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);
@@ -518,7 +528,6 @@ public:
     if (!indices.same_as(op->indices)) {
       BufferLoadNode *writer = load.CopyOnWrite();
       writer->indices = indices;
-      // writer->LegalizeDType();
       LegalizeBufferLoadDType(writer);
     }
@@ -533,21 +542,23 @@ public:
     // This is used to allow cases when we reuse a single let
     // expression to construct a nested expr.
     // (let x = 1 in x + 1) * (let x = 1 in x + 1)
-    auto it = let_binding_.find(op->var);
-    if (it != let_binding_.end()) {
+    auto it = let_var_map_.find(op->var);
+    if (it != let_var_map_.end()) {
       ICHECK(deep_equal_(it->second, value))
           << "Let cannot bind the same var to two different values";
     }
 
     if (value.dtype().get_lanes_or_vscale_factor() !=
         op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
-      let_binding_[op->var] = new_var;
+      let_var_map_[op->var] = new_var;
+      // Record mapping from the new var to its bound value
+      let_value_binding_[new_var] = value;
       return Let(new_var, value, this->VisitExpr(op->body));
     } else {
-      let_binding_[op->var] = op->var;
+      let_var_map_[op->var] = op->var;
       PrimExpr body = this->VisitExpr(op->body);
       if (value.same_as(op->value) && body.same_as(op->body)) {
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       } else {
         return Let(op->var, value, body);
       }
@@ -555,7 +566,7 @@ public:
   }
   // BufferStore
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store = GetRef<BufferStore>(op);
+    auto store = tvm::ffi::GetRef<BufferStore>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);
@@ -618,11 +629,11 @@ public:
     ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr extent = this->VisitExpr(op->extent);
     if (extent.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt body = this->VisitStmt(op->body);
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return For(op->loop_var, op->min, extent, op->kind, body,
                  op->thread_binding, op->annotations);
@@ -633,7 +644,7 @@ public:
     ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = std::nullopt;
@@ -642,7 +653,7 @@ public:
     }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return IfThenElse(condition, then_case, else_case);
     }
@@ -654,20 +665,23 @@ public:
   // LetStmt
   Stmt VisitStmt_(const LetStmtNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
-    ICHECK(!let_binding_.count(op->var))
+    ICHECK(!let_var_map_.count(op->var))
         << "SSA violation, a single var is binded twice";
-    let_binding_[op->var] = value;
 
     if (value.dtype().get_lanes_or_vscale_factor() !=
         op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
-      let_binding_[op->var] = new_var;
+      let_var_map_[op->var] = new_var;
+      // Record mapping from the new var to its bound value
+      let_value_binding_[op->var] = op->value;
+      let_value_binding_[new_var] = value;
       return LetStmt(new_var, value, this->VisitStmt(op->body));
     } else {
-      let_binding_[op->var] = op->var;
+      let_var_map_[op->var] = op->var;
+      let_value_binding_[op->var] = value;
       Stmt body = this->VisitStmt(op->body);
       if (value.same_as(op->value) && body.same_as(op->body)) {
-        return GetRef<Stmt>(op);
+        return tvm::ffi::GetRef<Stmt>(op);
       } else {
         return LetStmt(op->var, value, body);
       }
@@ -681,7 +695,7 @@ public:
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc of "
                    << op->buffer_var->name_hint;
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     return StmtMutator::VisitStmt_(op);
@@ -689,8 +703,27 @@ public:
   // scalarize the statement
   Stmt Scalarize(Stmt stmt) {
-    Var idx(var_->name_hint + ".s", var_->dtype);
+    Var idx(var_->name_hint + "_s", var_->dtype);
+    // Find all Vars in stmt that are keys in let_value_binding_
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_let_bound_vars;
+    PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) {
+      if (const auto *v = node.as<VarNode>()) {
+        Var var = GetRef<Var>(v);
+        if (let_value_binding_.count(var)) {
+          used_let_bound_vars.insert(var);
+        }
+      }
+    });
     stmt = Substitute(stmt, {{var_, idx}});
+    if (!used_let_bound_vars.empty()) {
+      for (const auto &v : used_let_bound_vars) {
+        // Bind the existing var v to its value around the stmt scope
+        auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}});
+        stmt = LetStmt(v, new_value, stmt);
+      }
+    }
     return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
   }
@@ -707,8 +740,11 @@ private:
   PrimExpr ramp_;
   // flag to mark requirement of scalarization.
   bool need_scalarize_{false};
-  // Let binding
-  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
+  // Let var mapping
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_var_map_;
+  // Let value binding: map new_var -> value
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
+      let_value_binding_;
   // vectorizable property
   OpAttrMap<TVectorizable> op_vectorizable_ =
       Op::GetAttrMap<TVectorizable>("TVectorizable");
@@ -746,7 +782,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();
@@ -762,7 +798,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();
@@ -806,7 +842,7 @@ public:
                  << " for target " << Target::Current();
     }
     ICHECK(is_zero(op->min));
-    return TLVectorizer(op->loop_var, op->extent)(op->body);
+    return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body);
   } else {
     return StmtMutator::VisitStmt_(op);
   }
@@ -842,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
   return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
-});
+}
 
 } // namespace tl
 } // namespace tvm
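
The Scalarize change above exists because vectorization may have rebound a let var to a new vector-typed var; when a statement is later scalarized back into a serial loop, any such var used in the body must be re-bound around the loop body, with the lane var substituted by the new loop index, or it escapes as a free variable. A string-level sketch of the rebinding (toy representation, not the TIR types):

let_value_binding = {"y": "x + 1"}       # y was bound during vectorization
body = "A[y] = 0"                        # the scalarized body still uses y

idx = "x_s"                              # fresh serial loop index
rebound_value = let_value_binding["y"].replace("x", idx)
loop = f"for {idx} in range(lanes): (let y = {rebound_value} in {body})"
assert "x_s + 1" in loop                 # y is bound again, not free
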
@@ -159,7 +159,7 @@ public:
     // Check reads from global
     Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
-                /*body*/ GetRef<Stmt>(op));
+                /*body*/ tvm::ffi::GetRef<Stmt>(op));
     auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
     auto reads = access[0];
     Role role = Role::kProducer;
@@ -511,7 +511,7 @@ private:
     annotations.Set(String("stmt_group"), Integer(1));
     auto original_node = (op->body).as<SeqStmtNode>();
     if (!original_node) {
-      return GetRef<For>(op);
+      return tvm::ffi::GetRef<For>(op);
     }
     Array<Stmt> new_body;
     int cur_id = 0;
@@ -646,7 +646,7 @@ private:
     if (role == Role::kBoth) {
       return StmtMutator::VisitStmt_(op);
     } else if ((role == Role::kProducer) == is_emitting_producer_) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return Evaluate(0);
     }
@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() {
       return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
                                                  disable_shuffle_elect);
     } else {
-      ObjectRef node = String("default");
+      auto node = ffi::String("default");
       f.CopyOnWrite()->body =
           AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
       return f;
@@ -1293,10 +1293,10 @@
   return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
   return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
 }
 
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
-});
+}
 
 } // namespace tl
 } // namespace tvm
@@ -22,15 +22,6 @@ def tl_matmul(
     b_transposed=True,
     k_pack=1,
 ):
-    assert in_dtype in [
-        "float16",
-        "int8",
-    ], "Currently only float16 and int8 are supported"
-    assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
-    ], "Currently only float16, float32 and int32 are supported"
 
     micro_size_x = micro_size_y = micro_size_k = 16
@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
......
@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
 @tilelang.testing.requires_rocm
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
         256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
         k_pack=2,
         b_preshuffle=True)
 
+    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256,
+        256,
+        512,
+        "float8_e4m3fnuz",
+        "float32",
+        k_pack=2,
+        b_transposed=False,
+        b_preshuffle=True)
+
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -223,29 +223,26 @@ def run_gemm_rs(
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 
-@tilelang.testing.requires_rocm
-def test_gemm_rs_f16f32f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
-
-@tilelang.testing.requires_rocm
-def test_gemm_rs_bf16f32f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
-
-@tilelang.testing.requires_rocm
-def test_gemm_rs_bf16bf16f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_f16f32f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
+
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_bf16f32f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
+
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_bf16bf16f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
 
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -514,4 +514,5 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
 
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_static_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            T.fill(x[0:128], 0)

    return buggy_kernel


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_dynamic_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            a, b = T.alloc_var('int'), T.alloc_var('int')
            T.fill(x[a:b], 0)

    return buggy_kernel


def test_fill_with_static_region_kernel():
    kernel = _fill_with_static_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


def test_fill_with_dynamic_region_kernel():
    kernel = _fill_with_dynamic_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


if __name__ == '__main__':
    tilelang.testing.main()
import torch
import tilelang
import tilelang.language as T


def test_int64_address():

    @tilelang.jit
    def set_cache_kernel(
        S,
        D,
        pos_ty='int64',
        dtype="float32",
    ):

        @T.prim_func
        def main(
                # Regression note: int64 positions used to fail with
                # `TypeError: Check failed: (a.dtype() == b.dtype()) is false:
                # mismatched types. int64 vs. int32`
                pos: T.Tensor([S], pos_ty),  # type: ignore
                value: T.Tensor([S, D], dtype),  # type: ignore
                cache: T.Tensor([S, D], dtype),  # type: ignore
        ):
            with T.Kernel(S, threads=128) as bx:
                slot = pos[bx]
                for i in T.Parallel(D):
                    cache[slot, i] = value[bx, i]

        return main

    D = 2
    S = 10
    cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
    value = torch.rand((S, D), device='cuda', dtype=torch.float32)
    pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
    pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)
    kernel_int64 = set_cache_kernel(S, D, 'int64')
    kernel_int32 = set_cache_kernel(S, D, 'int32')
    kernel_int64(pos_int64, value, cache)
    torch.testing.assert_close(cache, value)
    kernel_int32(pos_int32, value, cache)
    torch.testing.assert_close(cache, value)


if __name__ == "__main__":
    tilelang.testing.main()
import tilelang.testing
import tilelang.language as T


def test_issue_1198():

    @T.prim_func
    def foo(x: T.Buffer([32], "int32")):
        pass


if __name__ == '__main__':
    tilelang.testing.main()
import tilelang
import tilelang.language as T
import tilelang.testing


def _make_kernel(M, N):
    dtype = "bfloat16"

    @T.prim_func
    def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
        with T.Kernel(4, threads=1):
            A = T.alloc_shared([N], dtype)
            B = T.alloc_shared([N], dtype)
            # Regression for a bug where InjectSoftwarePipeline left the loop
            # variable as a free var, causing MakePackedAPI to fail
            for i in T.Pipelined(4, num_stages=1):
                _id = ids[i]
                T.copy(KV[_id, :], A)
                T.clear(B)

    return fwd_main


def test_make_packed_api_no_free_loop_var():
    func = _make_kernel(4, 4)
    # Keep warp-specialization/TMA disabled to match the original repro
    cfg = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    tilelang.compile(func, pass_configs=cfg)


if __name__ == "__main__":
    tilelang.testing.main()
import tilelang.testing
from tilelang import language as T


def test_issue_1237_dynamic_copy_extent_builds():
    # Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test.
    # The goal is to ensure T.copy correctly handles dynamic extents
    # (e.g., src slice length vs. static dst buffer size) during prim_func building.
    length = T.symbolic("len", dtype="int32")

    @T.prim_func
    def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]):  # noqa: F821
        with T.Kernel(1, threads=32):
            buffer_shared = T.alloc_shared((1024,), dtype="int32")
            T.copy(global_tensor[0:length], buffer_shared)

    # Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
    _ = sample_kernel


if __name__ == "__main__":
    tilelang.testing.main()
@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"
 
-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape():
 
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_gemm_f16f16f16_nn()
@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"
 
-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
......
import tilelang.testing
import tilelang
import torch


@tilelang.jit(
    out_idx=-1,  # create the output tensor during runtime
    verbose=True,
)
def matmul_kernel_jit(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A=False,
    trans_B=True,
    in_dtype='float16',
    out_dtype='float32',
    accum_dtype='float32',
    num_stages=2,
    threads=128,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def test_par_compile():
    configs = [
        (1024, 1024, 1024, 128, 128, 32),
        (2048, 2048, 2048, 256, 256, 64),
        (4096, 4096, 4096, 64, 64, 128),
    ]
    kernels = matmul_kernel_jit.par_compile(configs)
    for (M, N, K, _, _, _), kernel in zip(configs, kernels):
        A = torch.randn(M, K, dtype=torch.float16).cuda()
        B = torch.randn(N, K, dtype=torch.float16).cuda()
        ref = (A @ B.T).float()
        C = kernel(A, B)
        tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    tilelang.testing.main()