Commit d946d1d4 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Fix for T.copy with dynamic range (#462)

* [Refactor] Update barrier functions and remove argparse in example_warp_specialize_flashmla.py

* Refactored barrier functions to use new signatures for improved clarity and consistency.
* Replaced `mbarrier_arrive` and `mbarrier_wait_parity` with `barrier_arrive` and `barrier_wait` respectively.
* Removed argparse dependency and replaced it with hardcoded parameters for batch size and dimensions in the main function, simplifying the example script.

* [Refactor] Update warp_specialized_rewriter with license change and code cleanup

* Replaced Apache License header with MIT License in `warp_specialized_rewriter.cc`.
* Removed the `ThreadTagChecker` class to streamline the code, as it was no longer needed.
* Added `#include` for `common/collector.h` to support new functionality.
* Updated file documentation to reflect the correct filename and purpose.
* Improved overall code readability by removing unnecessary comments and sections.

* [Feature] Add thread synchronization functions in builtin.py and refine buffer region checks in copy.py

* Introduced `sync_threads` and `sync_thread_partial` functions in `builtin.py` for improved thread synchronization capabilities.
* Enhanced documentation for new synchronization functions to clarify usage and parameters.
* Updated buffer region validation in `copy.py` to ensure type checking for integer values, improving error handling for region extents.

* lint fix

* [Feature] Introduce TMA barrier injection and related utilities

* Added `inject_tma_barrier.cc` to implement TMA barrier rewriting for CUDA GPU (sm90+).
* Created `common/attr.h` and `common/collector.h` for attribute checks and information collection from the IR.
* Updated `ir.cc` to use a constant for the main block name instead of a hardcoded string.
* Cleaned up `warp_specialized_rewriter.cc` by removing unnecessary whitespace.
* Enhanced thread tag validation with `ThreadTagChecker` to ensure only `threadIdx.x` is used in TMA barrier contexts.

* lint fix
parent dd7eb488
......@@ -4,15 +4,13 @@
*
*/
#include "./transform/common/attr.h"
#include <tvm/arith/analyzer.h>
#include <tvm/script/ir_builder/tir/ir.h>
namespace tvm {
namespace tl {
constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";
using namespace script::ir_builder::tir;
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
......@@ -204,11 +202,11 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
}
if (attrs.defined()) {
auto empty_block = Block("root");
auto empty_block = Block(MainBlockName);
empty_block->annotations = attrs;
n->frames.push_back(empty_block);
} else {
n->frames.push_back(Block("root"));
n->frames.push_back(Block(MainBlockName));
}
return KernelLaunchFrame(n);
......
/*!
* \file attr.h
* \brief Check attributes of the IR
*/
namespace tvm {
namespace tl {
constexpr const char *MainBlockName = "tilelang_root";
constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";
} // namespace tl
} // namespace tvm
/*!
* \file collector.h
* \brief Collect information from the IR
*/
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
class ThreadTagChecker : public StmtExprVisitor {
public:
static bool HasOnlyThreadIdxX(const PrimFunc &f) {
ThreadTagChecker checker;
checker(f->body);
return checker.is_valid_;
}
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iter_var = Downcast<IterVar>(op->node);
String thread_tag = iter_var->thread_tag;
bool is_y_or_z =
thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
is_valid_ = false;
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kThreadBinding) {
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
bool is_y_or_z =
thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
if (!thread_tag.empty() && is_y_or_z) {
auto iter_var = Downcast<IterVar>(op->thread_binding);
if (iter_var.defined() && iter_var->dom.defined() &&
!is_one(iter_var->dom->extent)) {
is_valid_ = false;
}
}
}
StmtExprVisitor::VisitStmt_(op);
}
bool is_valid_ = true;
};
} // namespace tl
} // namespace tvm
......@@ -34,6 +34,8 @@
#include "../op/builtin.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "./common/collector.h"
#include "./common/attr.h"
namespace tvm {
namespace tl {
......@@ -41,6 +43,7 @@ namespace tl {
using namespace tir;
using namespace tir::transform;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
class TmaTraitsCollector : public StmtExprVisitor {
public:
......@@ -91,17 +94,6 @@ private:
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
IterVarType::kDataPar);
PrimExpr GetBarrierId(const PrimExpr &ko) {
// FloorMod(ko, 1)
return FloorMod(ko, IntImm(DataType::Int(32), 1));
}
PrimExpr GetBarrierParity(const PrimExpr &ko) {
// FloorDiv(ko, 1) % 2
return FloorMod(FloorDiv(ko, IntImm(DataType::Int(32), 1)),
IntImm(DataType::Int(32), 2));
}
PrimExpr makeGetBarrier(PrimExpr barrier_id) {
return Call(DataType::Handle(), get_mbarrier(), {barrier_id});
}
......@@ -116,6 +108,7 @@ private:
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......@@ -168,13 +161,25 @@ private:
}
};
class TmaBarrierCollector : public StmtExprVisitor {
class TmaBarrierCollector : public IRVisitorWithAnalyzer {
public:
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id() {
return tma_op_to_barrier_id_;
}
Map<PrimExpr, IntImm> barrier_id_to_range() { return barrier_id_to_range_; }
private:
void UpdateBarrierRange(PrimExpr barrier_id, IntImm extent) {
if (barrier_id_to_range_.count(barrier_id)) {
auto old_extent = barrier_id_to_range_[barrier_id];
ICHECK_EQ(old_extent->value, extent->value)
<< "barrier_id: " << barrier_id << " has different extent";
barrier_id_to_range_.Set(barrier_id, extent);
} else {
barrier_id_to_range_.Set(barrier_id, extent);
}
}
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load())) {
......@@ -186,33 +191,76 @@ private:
for (auto tma_call : pending_tma_ops_) {
tma_op_to_barrier_id_.Set(tma_call, barrier_id);
}
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
pending_tma_ops_.clear();
} else if (call->op.same_as(builtin::ptx_wait_barrier())) {
PrimExpr barrier_id = call->args[0];
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
IterVar thread_var_;
std::vector<Call> pending_tma_ops_;
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
Map<PrimExpr, IntImm> barrier_id_to_range_;
};
// we trust mbarrier_wait_parity to be correct
class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
public:
TmaBarrierRewriter(arith::Analyzer *analyzer,
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id)
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id,
Map<PrimExpr, IntImm> barrier_id_to_range,
bool has_create_list_of_mbarrier)
: IRMutatorWithAnalyzer(analyzer),
tma_op_to_barrier_id_(tma_op_to_barrier_id) {}
tma_op_to_barrier_id_(tma_op_to_barrier_id),
barrier_id_to_range_(barrier_id_to_range),
has_create_list_of_mbarrier_(has_create_list_of_mbarrier) {}
static PrimFunc Rewrite(PrimFunc f, arith::Analyzer *analyzer) {
f = TmaExpectTxRewriter::Rewrite(f, analyzer);
TmaBarrierCollector collector;
collector(f->body);
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id());
bool has_create_list_of_mbarrier = false;
PostOrderVisit(f->body, [&](const ObjectRef& node) {
if (const auto* call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier())) {
has_create_list_of_mbarrier = true;
}
}
});
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(),
collector.barrier_id_to_range(), has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
Stmt VisitStmt_(const BlockNode *op){
auto block = GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier.";
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
// check this must be in the tma_op_to_barrier_id_
......@@ -233,10 +281,21 @@ private:
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
Map<PrimExpr, IntImm> barrier_id_to_range_;
bool has_create_list_of_mbarrier_;
};
tvm::transform::Pass InjectTmaBarrier() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
// Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "InjectTmaBarrier will be disabled because the program "
"uses thread tags other than threadIdx.x\n"
<< "If you want to use TMA barrier, please refactor "
"your program to use threadIdx.x only";
// Return original function unchanged if other thread tags are found
return f;
}
arith::Analyzer analyzer;
return TmaBarrierRewriter::Rewrite(f, &analyzer);
};
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \file warp_specialized_rewriter.cc
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
......@@ -31,6 +12,7 @@
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "./common/collector.h"
namespace tvm {
namespace tl {
......@@ -932,49 +914,6 @@ private:
friend class WarpSpecializedRewriter;
};
class ThreadTagChecker : public StmtExprVisitor {
public:
static bool HasOnlyThreadIdxX(const PrimFunc &f) {
ThreadTagChecker checker;
checker(f->body);
return checker.is_valid_;
}
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iter_var = Downcast<IterVar>(op->node);
String thread_tag = iter_var->thread_tag;
bool is_y_or_z =
thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
is_valid_ = false;
}
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kThreadBinding) {
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
bool is_y_or_z =
thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
if (!thread_tag.empty() && is_y_or_z) {
auto iter_var = Downcast<IterVar>(op->thread_binding);
if (iter_var.defined() && iter_var->dom.defined() &&
!is_one(iter_var->dom->extent)) {
is_valid_ = false;
}
}
}
StmtExprVisitor::VisitStmt_(op);
}
bool is_valid_ = true;
};
class SetMaxNRegCollector : public StmtExprVisitor {
public:
static Array<IntImm> Collect(const PrimFunc &f) {
......
......@@ -292,3 +292,22 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
The value to shuffle
"""
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
def sync_threads():
"""Synchronize all threads in a warp.
"""
return tir.op.tvm_storage_sync("shared")
def sync_thread_partial(barrier_id: Union[int, PrimExpr, tir.Call]):
"""Synchronize threads within a warp.
Args:
barrier_id: Optional[int, PrimExpr]
The memory barrier to synchronize
Returns:
tir.Call: A handle to the synchronization operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_thread_partial"), barrier_id)
......@@ -83,7 +83,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
v = region_extents[i]
if v in tmp_extents:
tmp_extents.remove(v)
elif v != 1:
elif isinstance(v, tir.IntImm) and v != 1:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment