"src/vscode:/vscode.git/clone" did not exist on "2832520418fe81185b1a44f8421ec32f2efc1714"
Commit fecc8336 authored by Lei Wang, committed by LeiWang1999

[Enhancement] align shared memory allocations (#583)

* [Enhancement] Update `pythonic_expr` to format type casts and improve tensor validation in Cython wrapper

- Enhanced `pythonic_expr` to represent type casts as `(type)value` for better clarity in expression representation.
- Modified tensor validation in `CythonKernelWrapper` to conditionally check for tensor contiguity based on a new `skip_tensor_validation` parameter.
- Improved type mapping in `map_torch_type` to include version checks for new float8 types, ensuring compatibility with specific PyTorch versions.
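
A rough sketch of the version-gated mapping idea (illustrative only; the key names, the 2.1.0 cutoff, and the use of the `packaging` package are assumptions, not the commit's actual `map_torch_type` table):

    # Hypothetical sketch: add float8 entries only when the installed PyTorch has them.
    import torch
    from packaging import version

    def map_torch_type_sketch(name: str) -> torch.dtype:
        mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
        if version.parse(torch.__version__) >= version.parse("2.1.0"):
            # torch.float8_e4m3fn / torch.float8_e5m2 only exist in newer releases.
            mapping["float8_e4m3"] = torch.float8_e4m3fn
            mapping["float8_e5m2"] = torch.float8_e5m2
        return mapping[name]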

* [Feature] Implement dynamic shared memory allocation alignment

- Added a new transformation pass `AlignDynamicSharedMemoryAllocations` to align dynamic shared memory allocations to specified byte boundaries, enhancing memory access efficiency.
- Introduced a new utility class `TileLangAlignDynamicSharedMemoryAllocations` to handle the alignment logic for both allocation and buffer operations.
- Updated the `LowerAndLegalize` function to apply the alignment transformation based on the target device's capabilities, ensuring compatibility with different architectures.
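
A minimal sketch of how the new pass is applied to an IRModule (assumed usage, mirroring the `LowerAndLegalize` change later in this diff):

    # Sketch: pad dynamic shared memory buffers to a byte boundary before lowering.
    import tilelang

    def align_shared_memory(mod, align_bytes: int = 16):
        # The pipeline below picks 1024 bytes for Hopper-class (TMA) targets and
        # 16 bytes for everything else.
        return tilelang.transform.AlignDynamicSharedMemoryAllocations(align_bytes)(mod)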

* [Enhancement] Update dtype and argument defaults in GEMM autotuning example

- Changed data type from `float16` to `bfloat16` for a wider dynamic range in computations.
- Updated the default value of the `--with_roller` argument from `True` to `False` to modify the behavior of the autotuning process.

* [Enhancement] Improve thread range computation in storage access

- Added a new method `ComputeThreadRange` to calculate the range of threads for better access tracking.
- Updated `AccessEntry` structure to include `thread_range`.
- Modified various visitor methods to utilize `IRVisitorWithAnalyzer` for improved analysis during expression and statement visits.
- Ensured the thread range is computed and stored during buffer load and store operations, enabling more precise synchronization analysis.
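
The idea behind `ComputeThreadRange`, sketched with TVM's Python arithmetic analyzer (an illustration; the pass does this in C++ via `analyzer_.const_int_bound`):

    # Sketch: bind a thread var to its launch extent, then read back its constant bounds.
    import tvm
    from tvm import tir

    analyzer = tvm.arith.Analyzer()
    tx = tir.Var("threadIdx.x", "int32")
    analyzer.bind(tx, tvm.ir.Range.from_min_extent(0, 128))
    bound = analyzer.const_int_bound(tx)
    print(bound.min_value, bound.max_value)  # 0 127 -> thread range [0, 128)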

* [Refactor] Update comments for clarity in dynamic shared memory allocation alignment

- Translated comments in `align_dynamic_shared_memory_allocations.cc` from Chinese to English for better understanding.
- Removed an unnecessary call to `IRVisitorWithAnalyzer::VisitStmt_` in `storage_access.cc`.
- Added a blank line for improved readability in `thread_storage_sync.cc`.

* [Refactor] Enhance storage access analysis and thread range computation

- Introduced `ExtractRealCondition` to improve condition handling in `IfThenElseNode` visits.
- Updated `ComputeThreadRange` to use `Var` instead of `IterVar` for thread range mapping, enhancing clarity and consistency.
- Wrapped statement visits in `With<arith::ConstraintContext>` to ensure proper analysis context during condition evaluations.

* [Enhancement] Update default matrix dimensions in GEMM autotune example

- Changed default values for matrix dimensions M, N, and K from 16384 to 4096 in `example_gemm_autotune.py` to facilitate quicker testing and benchmarking.

* typo fix

* enhancement

* [Fix] Add conflict detection for buffer index size mismatch in thread storage sync

- Implemented a check to return true if the sizes of previous and current buffer indices do not match, indicating a conflict.
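
A hedged Python paraphrase of the rule (not the C++ implementation; the pass compares TIR expressions with `ExprDeepEqual`):

    # Sketch of the conflict rule: index lists that differ in size cannot be
    # proven to touch the same location, so treat them as a conflict.
    def indices_conflict(prev_indices, curr_indices):
        if len(prev_indices) != len(curr_indices):
            return True
        return any(p != c for p, c in zip(prev_indices, curr_indices))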
parent f4bb9f6c
@@ -84,7 +84,7 @@ def get_best_config(M, N, K, with_roller=False):
         thread_num=None,
         enable_rasteration=None,
     ):
-        dtype = "float16"
+        dtype = "bfloat16"
         accum_dtype = "float"

         @T.prim_func
@@ -204,11 +204,11 @@ def matmul(M,
     return gemm_autotune


-def main(m: int = 16384,
-         n: int = 16384,
-         k: int = 16384,
+def main(m: int = 4096,
+         n: int = 4096,
+         k: int = 4096,
          use_autotune: bool = False,
-         with_roller: bool = True):
+         with_roller: bool = False):
     M, N, K = m, n, k
     use_autotune = True
     if use_autotune:
@@ -232,9 +232,9 @@ def main(m: int = 16384,
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
-    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
-    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
-    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
+    parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
+    parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
+    parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
     parser.add_argument(
         "--use_autotune",
         action="store_true",
@@ -243,7 +243,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--with_roller",
         action="store_true",
-        default=True,
+        default=False,
         help="Whether to enable BitBLAS roller for search space")
     args = parser.parse_args()
     main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
...
@@ -6,7 +6,8 @@ import example_gemm


 def test_example_gemm_autotune():
-    example_gemm_autotune.main()
+    # enable roller for fast tuning
+    example_gemm_autotune.main(with_roller=True)


 def test_example_gemm_intrinsics():
...
/*!
* \file align_dynamic_shared_memory_allocations.cc
* \brief align dynamic shared memory allocations
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator {
public:
explicit TileLangAlignDynamicSharedMemoryAllocations(int align_bytes)
: align_bytes_(align_bytes) {}
static Stmt Substitute(int align_bytes, Stmt stmt) {
TileLangAlignDynamicSharedMemoryAllocations smem_rewriter(align_bytes);
return smem_rewriter.VisitStmt(stmt);
}
Stmt VisitStmt_(const AllocateNode *op) final {
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
auto new_extents =
MakeRoundRobinAlignment(op->extents, align_bytes_, op->dtype.bytes());
if (!new_extents.same_as(op->extents)) {
auto new_allocate = Allocate(op->buffer_var, op->dtype, new_extents,
op->condition, op->body, op->annotations);
return StmtExprMutator::VisitStmt(new_allocate);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) {
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buf->data));
if (storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn") {
auto new_shape = MakeRoundRobinAlignment(buf->shape, align_bytes_,
buf->dtype.bytes());
if (!new_shape.same_as(buf->shape)) {
ObjectPtr<BufferNode> new_buffer =
make_object<BufferNode>(*(buf.get()));
new_buffer->shape = std::move(new_shape);
buffer_remap_.Set(buf, Buffer(new_buffer));
return Buffer(new_buffer);
}
}
return buf;
});
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
}
return StmtExprMutator::VisitStmt_(block.get());
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store_node = GetRef<BufferStore>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
return BufferStore(buf, op->value, op->indices);
}
return StmtExprMutator::VisitStmt_(store_node.get());
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load_node = GetRef<BufferLoad>(op);
Buffer buf = op->buffer;
if (buffer_remap_.count(buf)) {
buf = buffer_remap_[buf];
return BufferLoad(buf, op->indices);
}
return StmtExprMutator::VisitExpr_(load_node.get());
}
private:
static Array<PrimExpr> MakeRoundRobinAlignment(Array<PrimExpr> extents,
int align_bytes,
int dtype_bytes) {
if (extents.empty())
return extents;
// Calculate total number of elements
PrimExpr total_elems = make_const(extents[0].dtype(), 1);
for (auto extent : extents) {
total_elems = total_elems * extent;
}
// Calculate total bytes
PrimExpr total_bytes = total_elems * dtype_bytes;
// Check if already aligned
PrimExpr remainder = indexmod(total_bytes, align_bytes);
if (is_zero(remainder)) {
return extents;
}
// Need to pad the last dimension
Array<PrimExpr> adjusted;
for (size_t i = 0; i < extents.size(); ++i) {
adjusted.push_back(extents[i]);
}
// Calculate padded last dimension
// pad = ceil(total_bytes / align_bytes) * align_bytes
PrimExpr last_extent = extents.back();
PrimExpr other_elems = make_const(extents[0].dtype(), 1);
for (size_t i = 0; i < extents.size() - 1; ++i) {
other_elems = other_elems * extents[i];
}
// new_last_extent = ceil(total_bytes / align_bytes) * align_bytes /
// (other_elems * dtype_bytes)
PrimExpr padded_total_bytes =
floordiv(total_bytes + align_bytes - 1, align_bytes) * align_bytes;
PrimExpr new_last_extent =
floordiv(padded_total_bytes, other_elems * dtype_bytes);
adjusted.Set(adjusted.size() - 1, new_last_extent);
return adjusted;
}
int align_bytes_;
Map<Buffer, Buffer> buffer_remap_;
};
tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) {
using namespace tir::transform;
auto pass_func = [align_bytes](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = TileLangAlignDynamicSharedMemoryAllocations::Substitute(
align_bytes, n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0,
"tl.AlignDynamicSharedMemoryAllocations", {});
}
TVM_REGISTER_GLOBAL("tl.transform.AlignDynamicSharedMemoryAllocations")
.set_body_typed(AlignDynamicSharedMemoryAllocations);
} // namespace tl
} // namespace tvm
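
A worked example of the round-up arithmetic above, in plain Python (an illustration of `MakeRoundRobinAlignment` under the assumption of static integer extents; the pass itself operates on `PrimExpr`):

    # Pad the last extent so the buffer's total byte size is a multiple of align_bytes.
    def align_last_extent(extents, align_bytes, dtype_bytes):
        total_bytes = dtype_bytes
        for e in extents:
            total_bytes *= e
        if total_bytes % align_bytes == 0:
            return list(extents)  # already aligned, shape is unchanged
        other_elems = 1
        for e in extents[:-1]:
            other_elems *= e
        padded_total = ((total_bytes + align_bytes - 1) // align_bytes) * align_bytes
        return list(extents[:-1]) + [padded_total // (other_elems * dtype_bytes)]

    # A float16 buffer of shape [33] is 66 bytes; aligned to 16 bytes it becomes [40] (80 bytes).
    print(align_last_extent([33], 16, 2))  # -> [40]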
@@ -22,6 +22,7 @@
  */
 #include "storage_access.h"
+#include <tvm/arith/analyzer.h>
 #include <tvm/target/target_info.h>
 #include <tvm/tir/op.h>
@@ -42,7 +43,9 @@ void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
     ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
     AccessEntry e;
     e.threads = env_threads();
+    e.thread_range = this->ComputeThreadRange(e.threads);
     e.buffer = buf;
+    e.buffer_indices = op->indices;
     e.dtype = op->dtype.element_of();
     for (const auto &index : op->indices) {
       e.touched.push_back(arith::IntSet::Vector(index));
@@ -52,7 +55,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
     curr_stmt_.access.emplace_back(std::move(e));
   }
   // traverse child
-  StmtExprVisitor::VisitExpr_(op);
+  IRVisitorWithAnalyzer::VisitExpr_(op);
 }

 void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
@@ -65,7 +68,9 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
   if (Enabled(buf.get(), scope)) {
     AccessEntry e;
     e.threads = env_threads();
+    e.thread_range = this->ComputeThreadRange(e.threads);
     e.buffer = buf;
+    e.buffer_indices = op->indices;
     e.dtype = op->value.dtype().element_of();
     for (const auto &index : op->indices) {
       e.touched.push_back(arith::IntSet::Vector(index));
@@ -75,7 +80,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
     curr_stmt_.access.emplace_back(std::move(e));
   }
   // traverse child
-  StmtExprVisitor::VisitStmt_(op);
+  IRVisitorWithAnalyzer::VisitStmt_(op);
   // push to the scope
   scope_.back().push_back(curr_stmt_);
   // clear access entry.
@@ -87,7 +92,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
   allow_append_ = true;
   ICHECK_EQ(curr_stmt_.access.size(), 0U);
   curr_stmt_.stmt = op;
-  StmtExprVisitor::VisitStmt_(op);
+  IRVisitorWithAnalyzer::VisitStmt_(op);
   // push to the scope
   if (curr_stmt_.access.size() != 0) {
     scope_.back().push_back(curr_stmt_);
@@ -115,7 +120,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) {
     ICHECK(double_buffer_write_ == nullptr);
     double_buffer_write_ = op->node.as<VarNode>();
     scope_.push_back(std::vector<StmtEntry>());
-    StmtExprVisitor::VisitStmt_(op);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
     StmtEntry s;
     s.stmt = op;
     s.access = Summarize(std::move(scope_.back()), nullptr);
@@ -132,21 +137,25 @@
   } else if (op->attr_key == tvm::tir::attr::coproc_scope) {
     IterVar iv = Downcast<IterVar>(op->node);
     env_threads_.push_back(iv);
-    StmtExprVisitor::VisitStmt_(op);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
     env_threads_.pop_back();
   } else if (op->attr_key == tvm::tir::attr::thread_extent) {
     IterVar iv = Downcast<IterVar>(op->node);
     env_threads_.push_back(iv);
+    ICHECK_NE(iv->thread_tag.length(), 0U);
+    analyzer_.Bind(
+        iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value));
     if (!in_device_env_) {
       in_device_env_ = true;
       scope_.push_back(std::vector<StmtEntry>());
-      StmtExprVisitor::VisitStmt_(op);
+      IRVisitorWithAnalyzer::VisitStmt_(op);
       // no need to take the result as the thread barrier automatically syncs.
       Summarize(std::move(scope_.back()), nullptr);
       in_device_env_ = false;
       scope_.pop_back();
     } else {
-      StmtExprVisitor::VisitStmt_(op);
+      IRVisitorWithAnalyzer::VisitStmt_(op);
     }
     env_threads_.pop_back();
   } else if (op->attr_key == tvm::tir::attr::hand_threaded) {
@@ -154,13 +163,13 @@
     // this avoids control flow and read/write conflicts
     // between hand-threaded kernels and automatic threading
   } else {
-    StmtExprVisitor::VisitStmt_(op);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
   }
 }

 void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
   scope_.push_back(std::vector<StmtEntry>());
-  StmtExprVisitor::VisitStmt_(op);
+  IRVisitorWithAnalyzer::VisitStmt_(op);
   StmtEntry s;
   s.stmt = op;
   s.access = Summarize(std::move(scope_.back()), op);
@@ -206,18 +215,27 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
   allow_append_ = true;
   this->VisitExpr(op->condition);
+  PrimExpr real_condition = ExtractRealCondition(op->condition);
   curr_stmt_.access.clear();
   allow_append_ = false;
   scope_.push_back(std::vector<StmtEntry>());
-  this->VisitStmt(op->then_case);
+  {
+    With<arith::ConstraintContext> constraint(&analyzer_, real_condition);
+    this->VisitStmt(op->then_case);
+  }
   StmtEntry s;
   s.stmt = op;
   s.access = Summarize(std::move(scope_.back()), nullptr);
   scope_.pop_back();
   if (op->else_case) {
     scope_.push_back(std::vector<StmtEntry>());
-    this->VisitStmt(op->else_case.value());
+    {
+      With<arith::ConstraintContext> constraint(&analyzer_, real_condition);
+      this->VisitStmt(op->else_case.value());
+    }
     auto v = Summarize(std::move(scope_.back()), nullptr);
     scope_.pop_back();
     s.access.insert(s.access.end(), v.begin(), v.end());
@@ -258,8 +276,10 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       ICHECK(allow_append_);
       AccessEntry e;
       e.threads = env_threads();
+      e.thread_range = this->ComputeThreadRange(e.threads);
       e.dtype = dtype;
       e.buffer = Downcast<Var>(buffer->data);
+      e.buffer_indices = load->indices;
       for (const auto &index : load->indices) {
         e.touched.push_back(arith::IntSet::Vector(index));
       }
@@ -267,9 +287,9 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.scope = scope;
       curr_stmt_.access.emplace_back(e);
     }
-    StmtExprVisitor::VisitExpr_(load);
+    IRVisitorWithAnalyzer::VisitExpr_(load);
   } else {
-    StmtExprVisitor::VisitExpr_(op);
+    IRVisitorWithAnalyzer::VisitExpr_(op);
   }
 } else if (op->op.same_as(builtin::tvm_access_ptr())) {
   ICHECK_EQ(op->args.size(), 5U);
@@ -284,8 +304,10 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     ICHECK(allow_append_);
     AccessEntry e;
     e.threads = env_threads();
+    e.thread_range = this->ComputeThreadRange(e.threads);
     e.dtype = dtype;
     e.buffer = Downcast<Var>(op->args[1]);
+    e.buffer_indices = {offset, extent};
     e.touched = {
         arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
     e.scope = scope;
@@ -298,7 +320,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     curr_stmt_.access.emplace_back(e);
   }
   }
-  StmtExprVisitor::VisitExpr_(op);
+  IRVisitorWithAnalyzer::VisitExpr_(op);
 } else if (op->op.same_as(builtin::tvm_storage_sync())) {
   ICHECK(allow_append_);
   const std::string &s = op->args[0].as<StringImmNode>()->value;
@@ -306,13 +328,33 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
   StorageScope scope = StorageScope::Create(s);
   AccessEntry e;
   e.threads = env_threads();
+  e.thread_range = this->ComputeThreadRange(e.threads);
   e.type = kSync;
   e.scope = StorageScope::Create(s);
   curr_stmt_.access.emplace_back(std::move(e));
   }
 } else {
-  StmtExprVisitor::VisitExpr_(op);
+  IRVisitorWithAnalyzer::VisitExpr_(op);
 }
 }
+
+Map<Var, Range>
+TileLangStorageAccessVisitor::ComputeThreadRange(Array<IterVar> threads) {
+  Map<Var, Range> thread_range;
+  for (const auto &th : threads) {
+    auto thread_tag = th->thread_tag;
+    if (thread_tag == "threadIdx.x" || thread_tag == "threadIdx.y" ||
+        thread_tag == "threadIdx.z") {
+      auto const_int_bound = analyzer_.const_int_bound(th->var);
+      auto min_value = const_int_bound->min_value;
+      auto max_value = const_int_bound->max_value;
+      auto extent = max_value - min_value + 1;
+      auto dtype = th->var.dtype();
+      thread_range.Set(th->var, Range::FromMinExtent(IntImm(dtype, min_value),
+                                                     IntImm(dtype, extent)));
+    }
+  }
+  return thread_range;
+}

 StorageScope TileLangStorageAccessVisitor::GetScope(Var buffer_var) const {
...
@@ -61,7 +61,10 @@ public:
   struct AccessEntry {
     /*! \brief The thread index that access this entry */
     Array<IterVar> threads;
+    /*! \brief The touched thread range */
+    Map<Var, Range> thread_range;
     /*! \brief The buffer variable, if any */
+    Array<PrimExpr> buffer_indices;
     Var buffer = NullValue<Var>();
     /*! \brief The access data type */
     DataType dtype;
@@ -125,6 +128,14 @@ protected:
    */
   virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
                                              const ForNode *loop) = 0;
+  /*!
+   * \brief Compute the thread range for the given threads.
+   * \param threads The threads to compute the range for.
+   * \return The thread range.
+   */
+  Map<Var, Range> ComputeThreadRange(Array<IterVar> threads);
   /*!
    * \brief Get the scope of the buffer array.
    * \return The scope of the final buffer array.
...
@@ -158,6 +158,7 @@ protected:
     std::vector<AccessEntry> head, tail;
     AccessEntry esync;
     esync.threads = this->env_threads();
+    esync.thread_range = this->ComputeThreadRange(esync.threads);
     esync.type = kSync;
     esync.scope = sync_scope_;
@@ -220,39 +221,32 @@ private:
     // Same index value means no conflicts
     // TODO(tqchen) more standard set based testing.
     bool has_same_index = true;
-    // Even if access has the same index, those indices need to
-    // depend on the innermost thread id to avoid race condition
-    bool depends_on_thread_index = true;
-    const VarNode *thread_index_var = nullptr;
-    if (!curr.threads.empty()) {
-      thread_index_var = curr.threads.back()->var.get();
-    }
-
-    for (size_t i = 0; i < prev.touched.size(); i++) {
-      const auto &prev_intset = prev.touched[i];
-      const auto &curr_intset = curr.touched[i];
-
-      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
-        PrimExpr prev_index = prev_intset.PointValue();
-        PrimExpr curr_index = curr_intset.PointValue();
-        has_same_index = ExprDeepEqual()(prev_index, curr_index);
-        if (thread_index_var != nullptr) {
-          auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
-            return parameter == thread_index_var;
-          };
-          depends_on_thread_index = depends_on_thread_index &&
-                                    UsesVar(curr_index, f_uses_thread_index) &&
-                                    UsesVar(prev_index, f_uses_thread_index);
-        }
-      } else {
+    bool range_is_equal = true;
+    for (const auto &kv : prev.thread_range) {
+      if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) {
+        range_is_equal = false;
+        break;
+      }
+    }
+    if (prev.buffer_indices.size() != curr.buffer_indices.size()) {
+      // They are not the same indices, should be conflict.
+      return true;
+    }
+
+    for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
+      const auto &prev_indice = prev.buffer_indices[i];
+      const auto &curr_indice = curr.buffer_indices[i];
+      if (!ExprDeepEqual()(prev_indice, curr_indice)) {
         has_same_index = false;
       }
-      if (!(has_same_index && depends_on_thread_index)) {
+      if (!(has_same_index)) {
         break;
       }
     }
-    if (has_same_index && depends_on_thread_index) {
+
+    if (has_same_index && range_is_equal) {
       return false;
     }
@@ -261,7 +255,6 @@ private:
     if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
       return false;
     }
-
     // If nothing else allows sharing the same buffer, then they are
    // in conflict.
    return true;
...
@@ -39,33 +39,30 @@ def run_passes(func: tvm.tir.PrimFunc):
     return tilelang.transform.ThreadSync("shared")(mod)


-@tvm.testing.requires_cuda
-def test_thread_storage_sync():
-    m = te.size_var("m")
-    l = te.size_var("l")
-    A = te.placeholder((m, l), name="A")
-
-    A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
-    A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
-
-    s = te.create_schedule(A2.op)
-    xo, xi = s[A2].split(A2.op.axis[0], factor=8)
-    s[A2].bind(xo, te.thread_axis("blockIdx.x"))
-    s[A1].compute_at(s[A2], xo)
-    s[A1].set_scope("shared")
-
-    bounds = tvm.te.schedule.InferBound(s)
-    assert isinstance(bounds, tvm.container.Map)
-    stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-    func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
+@tilelang.testing.requires_cuda
+def test_sync_if_with_same_index():
+
+    @T.prim_func
+    def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
+        threadIdx_x = T.env_thread("threadIdx.x")
+        threadIdx_y = T.env_thread("threadIdx.y")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        p0 = T.Buffer([2], dtype="float32", data=p0_arg.data)
+        result_local = T.alloc_buffer([1], dtype="float32", scope="local")
+        temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
+        T.launch_thread(blockIdx_x, 8)
+        T.launch_thread(threadIdx_x, 4)
+        result_local[0] = T.float32(0)
+        if threadIdx_y < 8:
+            temp_shared[threadIdx_x] = p0[0]
+            temp_shared[threadIdx_x] = temp_shared[threadIdx_x]
+            result_local[0] = result_local[0] + temp_shared[0]

     mod = run_passes(func)
-    f = mod["test_kernel"]
-    body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
-    assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
+    assert "T.tvm_storage_sync" in str(mod)


-@tvm.testing.requires_cuda
+@tilelang.testing.requires_cuda
 def test_sync_else_branch():

     def ir(A, B):
@@ -101,7 +98,7 @@ def test_sync_else_branch():
     assert "T.tvm_storage_sync" in str(mod)


-@tvm.testing.requires_cuda
+@tilelang.testing.requires_cuda
 def test_sync_read_thread_id_independent_location():

     @T.prim_func
@@ -125,7 +122,7 @@ def test_sync_read_thread_id_independent_location():
     assert "T.tvm_storage_sync" in str(mod)


-@tvm.testing.requires_cuda
+@tilelang.testing.requires_cuda
 def test_sync_let_stmt():

     @T.prim_func(private=True)
...
@@ -165,6 +165,10 @@ class AutotuneResult:
         - kernel_lib.so: The compiled kernel library
         - params.pkl: The serialized kernel parameters
         """
+        if os.path.exists(cache_path):
+            logger.info(f"Cache path {cache_path} already exists, skipping saving kernel to disk")
+            return
+
         os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists

         # Save kernel source code
...
@@ -240,6 +240,11 @@ class AutoTunerCache:
         - params.pkl: The serialized kernel parameters
         """
         cache_path = self._get_cache_path(key)
+        if os.path.exists(cache_path):
+            self.logger.info(
+                f"Cache path {cache_path} already exists, skipping saving kernel to disk")
+            return
+
         os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists

         # Save kernel source code
...
@@ -430,6 +430,8 @@ def have_tma(target):
     target : tvm.target.Target
         The compilation target
     """
+    if target.kind.name != "cuda":
+        return False
     compute_version = get_target_compute_version(target)
     major, minor = parse_compute_version(compute_version)
     # TMA is supported in Ada Lovelace (9.0) or later architectures.
...
@@ -22,22 +22,16 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,

 def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
                                    target: Optional[Target] = None) -> bool:
-    # avoid circular import
-    from tilelang.jit.adapter.utils import is_cuda_target
-
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    if not is_cuda_target(target) or not have_tma(target):
+    if not have_tma(target):
         return False
     disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
     return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)


 def allow_fence_proxy(target: Optional[Target] = None) -> bool:
-    # avoid circular import
-    from tilelang.jit.adapter.utils import is_cuda_target
-
-    return is_cuda_target(target) and have_tma(target)
+    return have_tma(target)


 def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
@@ -72,12 +66,18 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
     # Add safety checks for memory accesses
     mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
+    # Align dynamic shared memory allocations
+    if have_tma(target):
+        # Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes
+        mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(1024)(mod)
+    else:
+        # For other devices, we align to 16 bytes
+        mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(16)(mod)
     # Simplify again to clean up any duplicated conditions
     # that may have been introduced by safety checks
     mod = tir.transform.Simplify()(mod)
     # Try to vectorize loop with dynamic shape
     mod = tilelang.transform.LoopVectorizeDynamic()(mod)
     return mod
...
@@ -353,4 +353,18 @@ def LowerL2Persistent():

 def PersistThreadblock():
     """PersistThreadblock
     """
-    return _ffi_api.PersistThreadblock()  # type: ignore
\ No newline at end of file
+    return _ffi_api.PersistThreadblock()  # type: ignore
+
+
+def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16):
+    """AlignDynamicSharedMemoryAllocations
+
+    Parameters
+    ----------
+    align_bytes: int
+        The alignment bytes.
+
+    Returns
+    -------
+    """
+    return _ffi_api.AlignDynamicSharedMemoryAllocations(align_bytes)  # type: ignore
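
The Python wrapper above forwards to the C++ pass registered earlier in this commit via `TVM_REGISTER_GLOBAL("tl.transform.AlignDynamicSharedMemoryAllocations")`; an illustrative check that the FFI binding is present (not part of the commit):

    import tvm
    import tilelang  # assumed to load the compiled library that registers the tl.* passes

    assert tvm.get_global_func(
        "tl.transform.AlignDynamicSharedMemoryAllocations", allow_missing=True) is not None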