Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -37,7 +37,7 @@
namespace tvm {
namespace tl {
using namespace ffi;
namespace tir = tvm::tir;
class HostDeviceSplitter : public tir::StmtMutator {
......@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});
}
} // namespace transform
} // namespace tl
......
......@@ -29,6 +29,7 @@
#include <string>
#include <utility>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
......@@ -38,10 +39,11 @@ using namespace tir;
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
Var buf = op->buffer->data;
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) {
ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
<< scope.to_string();
AccessEntry e;
e.threads = env_threads();
e.thread_range = this->ComputeThreadRange(e.threads);
......@@ -65,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
curr_stmt_.stmt = op;
Var buf = op->buffer->data;
buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) {
AccessEntry e;
......@@ -252,7 +254,11 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
this->VisitExpr(op->condition);
PrimExpr real_condition = ExtractRealCondition(op->condition);
curr_stmt_.access.clear();
// Preserve accesses collected from the condition expression so they
// participate in dependency analysis. Otherwise, a write to shared memory
// immediately followed by an if-condition reading that memory would not
// trigger a sync before the if-statement.
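  // e.g., "shared[i] = v; if (shared[j] > 0) { ... }" must see a sync
  // between the store and the condition read.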
std::vector<AccessEntry> cond_access = std::move(curr_stmt_.access);
allow_append_ = false;
scope_.push_back(std::vector<StmtEntry>());
......@@ -265,6 +271,11 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
// Merge the condition's access summary into the if-statement's access list
// so the planner can insert a sync before the if when necessary.
if (!cond_access.empty()) {
s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end());
}
if (op->else_case) {
scope_.push_back(std::vector<StmtEntry>());
{
......@@ -301,14 +312,32 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
}
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
// Mark async TMA load context so that tvm_access_ptr within the call
// can be tagged accordingly.
auto is_tma_load = [&]() {
if (auto opt = op->op.as<Op>()) {
const Op &call_op = opt.value();
return call_op.same_as(tl::tma_load()) ||
call_op.same_as(tl::tma_load_im2col());
}
return false;
}();
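  // While tma_depth_ > 0, any tvm_access_ptr visited inside the call is
  // tagged as an async copy (see is_async_copy below).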
if (is_tma_load) {
tma_depth_++;
for (const auto &a : op->args) {
this->VisitExpr(a);
}
tma_depth_--;
return;
}
if (op->op.same_as(builtin::address_of())) {
ICHECK_EQ(op->args.size(), 1U);
if (auto load = op->args[0].as<BufferLoadNode>()) {
Buffer buffer = load->buffer;
DataType dtype = buffer->dtype;
const VarNode *buffer_var = buffer->data.as<VarNode>();
buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
Array<Range> buffer_ranges;
// from indices to buffer indices
ICHECK(buffer->shape.size() == load->indices.size());
......@@ -346,17 +375,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
PrimExpr offset = op->args[2];
PrimExpr extent = op->args[3];
const IntImmNode *flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
// The buffer scope.
if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_);
Array<Range> buffer_ranges;
if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
buffer_data_to_buffer_.end()) {
          // buffer not found in the map; fall back to a flat [offset, offset + extent) range
buffer_ranges = {Range::FromMinExtent(offset, extent)};
} else {
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
Buffer buffer =
buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
auto buffer_shape = buffer->shape;
// convert 1d offset to multi-dimensional index
auto linear_to_indices = [this](PrimExpr offset,
......@@ -387,7 +417,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
e.threads = env_threads();
e.thread_range = this->ComputeThreadRange(e.threads);
e.dtype = dtype;
e.buffer = GetRef<Var>(buffer_var);
e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
e.buffer_ranges = buffer_ranges;
e.is_pointer_access = true;
e.touched = {
......@@ -395,10 +425,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
e.scope = scope;
if (flag->value & 1) {
e.type = kRead;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
if (flag->value & 2) {
e.type = kWrite;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
}
......
......@@ -39,6 +39,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank;
using runtime::StorageScope;
......@@ -83,6 +84,10 @@ public:
bool double_buffer_write = false;
/*! \brief Whether the access is pointer access */
bool is_pointer_access = false;
  /*! \brief Whether this access originates from an async copy context
   * (e.g., inside a TMA load); multiple such async writes do not
   * force barriers among themselves. */
bool is_async_copy = false;
};
/*! \brief Access pattern about a single statement */
......@@ -159,6 +164,8 @@ private:
bool allow_append_{false};
// Whether we are in device environment
bool in_device_env_{false};
// Nesting depth of tma_load/tma_load_im2col calls
int tma_depth_{0};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
......
......@@ -544,7 +544,7 @@ public:
}
return it->second->alloc_var;
} else {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
......@@ -679,7 +679,7 @@ private:
return !scope.tag.empty() && scope.tag != ".dyn" &&
scope.tag != ".barrier" && scope.tag != ".workspace" &&
scope.tag != ".vtcm" && scope.tag != ".var" &&
scope.tag != ".descriptor";
scope.tag.find(".descriptor") != 0;
}
// Allocate entry of node.
......@@ -865,7 +865,7 @@ private:
ICHECK_NE(e->const_nbits, 0U);
MemoryInfo info;
if (e->scope.tag != ".barrier" && e->scope.tag != ".var" &&
e->scope.tag != ".descriptor") {
e->scope.tag.find(".descriptor") != 0) {
info = GetMemoryInfo(e->scope.to_string());
}
uint64_t total_bits = e->const_nbits;
......@@ -978,8 +978,8 @@ private:
ICHECK(alloc_info.count(var));
const AllocEntry &entry = alloc_info.at(var);
const AllocateNode *alloc = entry.alloc;
auto storage_scope =
StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
auto storage_scope = StorageScope::Create(
GetPtrStorageScope(tvm::ffi::GetRef<Var>(var)));
StorageEntry *dst_entry = nullptr;
// inplace detection
if (detect_inplace) {
......@@ -1425,9 +1425,30 @@ public:
void
OnArrayDeclaration(const Var &buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) {
ICHECK(info_map_.find(buffer.get()) == info_map_.end())
<< "Array declaration of " << buffer->name_hint
<< " occurred multiple times.";
auto it = info_map_.find(buffer.get());
if (it != info_map_.end()) {
// The same buffer var may appear in more than one Allocate due to
// upstream transforms (e.g., storage planning/merging). Treat repeated
// declarations as benign and merge metadata instead of erroring.
BufferVarInfo &existing = it->second;
// Prefer a concrete element dtype if the previous one was a handle.
if (existing.element_dtype.is_handle() && !element_dtype.is_handle()) {
existing.element_dtype =
element_dtype == DataType::Bool()
? DataType::Int(8).with_lanes(element_dtype.lanes())
: element_dtype;
}
// If extent was previously unknown (0) and a concrete extent is
// provided now, record it.
if (!existing.extent.defined() || is_zero(existing.extent)) {
existing.extent = extent;
}
// Merge declaration locations (bitwise OR of flags).
existing.declaration_location =
static_cast<BufferVarInfo::DeclarationLocation>(
existing.declaration_location | declaration_location);
return;
}
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
......@@ -1732,7 +1753,7 @@ public:
Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
if (var.same_as(op->var) && value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
}
return LetStmt(var, value, body);
}
......@@ -1985,10 +2006,10 @@ Pass StorageRewrite() {
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
});
}
Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
......@@ -1997,11 +2018,11 @@ Pass PointerValueTypeRewrite() {
return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
PointerValueTypeRewrite);
});
}
} // namespace transform
} // namespace tl
......
......@@ -86,6 +86,7 @@ protected:
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
......@@ -98,7 +99,8 @@ protected:
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
if (FindConflict(reads, acc, false) ||
FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
......@@ -123,27 +125,51 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
}
}
if (loop != nullptr) {
// Check if the loop body contains any reads in the same sync scope.
// If there are reads, we conservatively keep the sync within the loop
// body to preserve per-iteration ordering when needed. If there are no
// reads (e.g., only writes to shared.dyn), we can safely hoist the sync
// to before the loop to avoid redundant barriers.
bool has_read_in_scope = false;
for (const StmtEntry &s : seq) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead && acc.scope == sync_scope_) {
has_read_in_scope = true;
break;
}
}
if (has_read_in_scope)
break;
}
      // If there is a loop-carried dependency, insert a single sync
      // before the loop rather than inserting one inside the loop body.
      // This avoids redundant per-iteration synchronizations when each
      // iteration touches disjoint regions (e.g., stmatrix writes to
      // shared.dyn) and only a global ordering before/after the loop is
      // required.
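      // For example, if every iteration writes a different tile of the same
      // shared.dyn buffer and the data is only consumed after the loop, a
      // single barrier before the loop is sufficient.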
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0)
break;
if (reads.empty() && writes.empty())
break;
bool sync_before_stmt = false;
bool need_loop_sync = false;
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
need_loop_sync = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
if (FindConflict(reads, acc, true) ||
FindConflict(writes, acc, true)) {
need_loop_sync = true;
break;
}
} else if (acc.type == kSync) {
......@@ -151,8 +177,17 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
if (need_loop_sync) {
if (!has_read_in_scope) {
// Mark the loop itself to receive a sync before it, instead of
// inserting inside the loop body. This ensures a single sync is
// emitted outside the loop and avoids per-iteration overhead.
insert_syncs(loop);
} else {
// Fall back to inserting before the first conflicting statement
// inside the loop to maintain correctness when reads are present.
insert_syncs(s.stmt);
}
break;
}
}
......@@ -217,6 +252,14 @@ private:
bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
bool loop_carry) {
// Special case: ignore conflicts between async-copy writes (e.g., TMA
// loads into shared memory). Multiple async writes do not require
// interspersed barriers among themselves. We still respect conflicts with
// reads to ensure visibility before consumption.
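    // e.g., back-to-back tma_load calls filling different pipeline stages of
    // shared memory can proceed without a barrier between them.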
if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
curr.is_async_copy) {
return false;
}
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
......@@ -241,10 +284,15 @@ private:
return true;
}
if (prev.is_pointer_access || curr.is_pointer_access) {
// If either access is a pointer access, conservatively assume a
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
// memory region, so we cannot safely determine if it overlaps with
// previous accesses.
// For accesses created via tvm_access_ptr we may still be able to prove
// disjointness using their byte ranges. If both sides expose a touched
// interval and we can show they don't overlap, skip the conflict.
if (prev.is_pointer_access && curr.is_pointer_access &&
PointerAccessIsDisjoint(prev, curr)) {
return false;
}
// Otherwise fall back to the conservative answer: treat them as
// overlapping.
return true;
}
......@@ -327,7 +375,7 @@ private:
}
}
if (!(has_same_index)) {
if (!has_same_index) {
break;
}
}
......@@ -350,6 +398,26 @@ private:
return range_is_overlap;
}
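  // Returns true only when both pointer accesses expose a single touched
  // interval and the analyzer can prove the intervals do not overlap,
  // e.g. offset ranges [0, 256) and [256, 512) of the same shared buffer.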
bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
return false;
}
PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
if (analyzer_.CanProve(lhs_max < rhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
if (analyzer_.CanProve(rhs_max < lhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
return false;
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
......@@ -782,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
});
}
} // namespace transform
} // namespace tl
......
......@@ -33,6 +33,7 @@
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -43,6 +44,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
/*!
* \brief Perform data type legalization on the given BufferLoadNode pointer.
......@@ -208,6 +210,14 @@ public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
// Convenience entry to vectorize a loop body without exposing
// the mutator invocation pattern at call sites.
static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
TLVectorizer vec{var, var_lanes};
auto vec_stmt = vec(std::move(body));
return vec_stmt;
}
TLVectorizer(const Var &var, const PrimExpr &var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
......@@ -217,8 +227,9 @@ public:
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
auto scalarized_stmt = Scalarize(stmt);
need_scalarize_ = false;
return Scalarize(stmt);
return scalarized_stmt;
} else {
return ret;
}
......@@ -242,7 +253,7 @@ public:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
......@@ -296,7 +307,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return !(a);
}
......@@ -337,10 +348,10 @@ public:
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Broadcast(op->value, op->lanes);
}
......@@ -352,7 +363,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
......@@ -370,7 +381,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(
......@@ -383,26 +394,26 @@ public:
}
PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
// Variable
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
}
auto it = let_binding_.find(var);
if (it != let_binding_.end()) {
auto it = let_var_map_.find(var);
if (it != let_var_map_.end()) {
return it->second;
} else {
return std::move(var);
......@@ -413,13 +424,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
......@@ -441,7 +452,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
......@@ -478,7 +489,6 @@ public:
bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false) &&
!op->dtype.is_scalable_vector();
if (!vectorizable) {
// Cannot vectorize this op
Array<PrimExpr> new_args;
......@@ -486,12 +496,12 @@ public:
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype, op->op, new_args);
}
......@@ -500,7 +510,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
......@@ -508,7 +518,7 @@ public:
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op);
auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -518,7 +528,6 @@ public:
if (!indices.same_as(op->indices)) {
BufferLoadNode *writer = load.CopyOnWrite();
writer->indices = indices;
// writer->LegalizeDType();
LegalizeBufferLoadDType(writer);
}
......@@ -533,21 +542,23 @@ public:
      // This allows cases where a single let expression is reused to
      // construct a nested expr, e.g.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
auto it = let_var_map_.find(op->var);
if (it != let_var_map_.end()) {
ICHECK(deep_equal_(it->second, value))
<< "Let cannot bind the same var to two different values";
}
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
let_var_map_[op->var] = new_var;
// Record mapping from the new var to its bound value
let_value_binding_[new_var] = value;
return Let(new_var, value, this->VisitExpr(op->body));
} else {
let_binding_[op->var] = op->var;
let_var_map_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
......@@ -555,7 +566,7 @@ public:
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op);
auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -618,11 +629,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations);
......@@ -633,7 +644,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt;
......@@ -642,7 +653,7 @@ public:
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
......@@ -654,20 +665,23 @@ public:
// LetStmt
Stmt VisitStmt_(const LetStmtNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
ICHECK(!let_binding_.count(op->var))
ICHECK(!let_var_map_.count(op->var))
<< "SSA violation, a single var is binded twice";
let_binding_[op->var] = value;
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
let_var_map_[op->var] = new_var;
// Record mapping from the new var to its bound value
let_value_binding_[op->var] = op->value;
let_value_binding_[new_var] = value;
return LetStmt(new_var, value, this->VisitStmt(op->body));
} else {
let_binding_[op->var] = op->var;
let_var_map_[op->var] = op->var;
let_value_binding_[op->var] = value;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
......@@ -681,7 +695,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
return StmtMutator::VisitStmt_(op);
......@@ -689,8 +703,27 @@ public:
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
Var idx(var_->name_hint + "_s", var_->dtype);
// Find all Vars in stmt that are keys in let_value_binding_
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_let_bound_vars;
PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) {
if (const auto *v = node.as<VarNode>()) {
Var var = GetRef<Var>(v);
if (let_value_binding_.count(var)) {
used_let_bound_vars.insert(var);
}
}
});
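    // Vars bound by an enclosing LetStmt may have been renamed or retyped
    // during vectorization; re-bind them around the scalarized body so the
    // names referenced by stmt stay defined.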
stmt = Substitute(stmt, {{var_, idx}});
if (!used_let_bound_vars.empty()) {
for (const auto &v : used_let_bound_vars) {
// Bind the existing var v to its value around the stmt scope
auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}});
stmt = LetStmt(v, new_value, stmt);
}
}
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
......@@ -707,8 +740,11 @@ private:
PrimExpr ramp_;
// flag to mark requirement of scalarization.
bool need_scalarize_{false};
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Let var mapping
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_var_map_;
// Let value binding: map new_var -> value
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
let_value_binding_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ =
Op::GetAttrMap<TVectorizable>("TVectorizable");
......@@ -746,7 +782,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......@@ -762,7 +798,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......@@ -806,7 +842,7 @@ public:
<< " for target " << Target::Current();
}
ICHECK(is_zero(op->min));
return TLVectorizer(op->loop_var, op->extent)(op->body);
return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
......@@ -842,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
});
}
} // namespace tl
} // namespace tvm
......@@ -159,7 +159,7 @@ public:
// Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op));
/*body*/ tvm::ffi::GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
Role role = Role::kProducer;
......@@ -511,7 +511,7 @@ private:
annotations.Set(String("stmt_group"), Integer(1));
auto original_node = (op->body).as<SeqStmtNode>();
if (!original_node) {
return GetRef<For>(op);
return tvm::ffi::GetRef<For>(op);
}
Array<Stmt> new_body;
int cur_id = 0;
......@@ -646,7 +646,7 @@ private:
if (role == Role::kBoth) {
return StmtMutator::VisitStmt_(op);
} else if ((role == Role::kProducer) == is_emitting_producer_) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return Evaluate(0);
}
......@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect);
} else {
ObjectRef node = String("default");
auto node = ffi::String("default");
f.CopyOnWrite()->body =
AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
return f;
......@@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
}
} // namespace tl
} // namespace tvm
......@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
});
}
} // namespace tl
} // namespace tvm
......@@ -22,15 +22,6 @@ def tl_matmul(
b_transposed=True,
k_pack=1,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
......@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
......
......@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
if in_dtype == "int8":
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
......@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
......@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
k_pack=2,
b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
assert_tl_matmul_correctness(
256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(
256,
256,
512,
"float8_e4m3fnuz",
"float32",
k_pack=2,
b_transposed=False,
b_preshuffle=True)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -223,29 +223,26 @@ def run_gemm_rs(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
@tilelang.testing.requires_rocm
def test_gemm_rs_f16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
@tilelang.testing.requires_rocm
def test_gemm_rs_bf16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
@tilelang.testing.requires_rocm
def test_gemm_rs_bf16bf16f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_f16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16bf16f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -514,4 +514,5 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
import torch
import tilelang
import tilelang.testing
from tilelang import language as T
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens')
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
T.fill(x[0:128], 0)
return buggy_kernel
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens')
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var('int'), T.alloc_var('int')
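            # a and b are runtime values, so the slice bounds of x[a:b] are
            # not compile-time constants; lowering must still succeed.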
T.fill(x[a:b], 0)
return buggy_kernel
def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)
def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)
if __name__ == '__main__':
tilelang.testing.main()
import torch
import tilelang
import tilelang.language as T
def test_int64_address():
@tilelang.jit
def set_cache_kernel(
S,
D,
pos_ty='int64',
dtype="float32",
):
@T.prim_func
def main(
            pos: T.Tensor([S], pos_ty),  # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
value: T.Tensor([S, D], dtype), # type: ignore
cache: T.Tensor([S, D], dtype), # type: ignore
):
with T.Kernel(S, threads=128) as bx:
slot = pos[bx]
for i in T.Parallel(D):
cache[slot, i] = value[bx, i]
return main
D = 2
S = 10
cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
value = torch.rand((S, D), device='cuda', dtype=torch.float32)
pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, 'int64')
kernel_int32 = set_cache_kernel(S, D, 'int32')
kernel_int64(pos_int64, value, cache)
torch.testing.assert_close(cache, value)
kernel_int32(pos_int32, value, cache)
torch.testing.assert_close(cache, value)
if __name__ == "__main__":
tilelang.testing.main()
import tilelang.testing
import tilelang.language as T
def test_issue_1198():
@T.prim_func
    def foo(x: T.Buffer([32], "int32")):
pass
if __name__ == '__main__':
tilelang.testing.main()
import tilelang
import tilelang.language as T
import tilelang.testing
def _make_kernel(M, N):
dtype = "bfloat16"
@T.prim_func
def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)
            # Regression test for a bug where InjectSoftwarePipeline left the
            # loop variable as a free var, causing MakePackedAPI to fail.
for i in T.Pipelined(4, num_stages=1):
_id = ids[i]
T.copy(KV[_id, :], A)
T.clear(B)
return fwd_main
def test_make_packed_api_no_free_loop_var():
func = _make_kernel(4, 4)
# Keep warp-specialization/TMA disabled to match the original repro
cfg = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
tilelang.compile(func, pass_configs=cfg)
if __name__ == "__main__":
tilelang.testing.main()
import tilelang.testing
from tilelang import language as T
def test_issue_1237_dynamic_copy_extent_builds():
# Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test.
# The goal is to ensure T.copy correctly handles dynamic extents
# (e.g., src slice length vs. static dst buffer size) during prim_func building.
length = T.symbolic("len", dtype="int32")
@T.prim_func
def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821
with T.Kernel(1, threads=32):
buffer_shared = T.alloc_shared((1024,), dtype="int32")
T.copy(global_tensor[0:length], buffer_shared)
# Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
_ = sample_kernel
if __name__ == "__main__":
tilelang.testing.main()
......@@ -85,7 +85,7 @@ def run_gemm(
stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
@tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
......@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_gemm_f16f16f16_nn()
......@@ -85,7 +85,7 @@ def run_gemm(
stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
@tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
......
import tilelang.testing
import tilelang
import torch
@tilelang.jit(
    out_idx=-1,  # create the output tensor at runtime
verbose=True,
)
def matmul_kernel_jit(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
num_stages=2,
threads=128,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def test_par_compile():
configs = [
(1024, 1024, 1024, 128, 128, 32),
(2048, 2048, 2048, 256, 256, 64),
(4096, 4096, 4096, 64, 64, 128),
]
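    # par_compile builds one kernel per config (compiled in parallel) and
    # returns them in the same order as the configs list.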
kernels = matmul_kernel_jit.par_compile(configs)
for (M, N, K, _, _, _), kernel in zip(configs, kernels):
A = torch.randn(M, K, dtype=torch.float16).cuda()
B = torch.randn(N, K, dtype=torch.float16).cuda()
ref = (A @ B.T).float()
C = kernel(A, B)
tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)
if __name__ == "__main__":
tilelang.testing.main()