Commit ec84188f authored by Lei Wang, committed by GitHub

[Refactor] Separate tilelang Pass Thread Sync (with Hopper support) from tvm (#85)

* Bump version to v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync
parent f55defac
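
For context, a minimal sketch of driving the separated pass from Python. The kernel below is a made-up stand-in modeled on the tests added in this commit; only the tilelang.transform.ThreadSync entry point comes from this change:

import tilelang
from tilelang import tvm as tvm
from tvm.script import tir as T

@T.prim_func(private=True)
def kernel(A: T.Buffer((128,), "float32")):
    tx = T.launch_thread("threadIdx.x", 128)
    buf = T.allocate([128], "float32", "shared")
    buf_shared = T.Buffer((128,), data=buf, scope="shared")
    buf_shared[tx] = A[tx]
    # Cross-thread read of data written above: a barrier is required in between.
    A[tx] = buf_shared[(tx + 1) % 128]

mod = tvm.IRModule({"main": kernel})
mod = tilelang.transform.ThreadSync("shared")(mod)  # inserts tvm_storage_sync("shared")
print(mod)
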
@@ -82,3 +82,6 @@ build_sdist/
# ignore lib with develop mode
tilelang/lib
# tox
.tox/
@@ -65,10 +65,19 @@ endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# Locate TVM prebuild path
if(NOT DEFINED TVM_PREBUILD_PATH)
if(DEFINED ENV{TVM_PREBUILD_PATH})
set(TVM_PREBUILD_PATH "$ENV{TVM_PREBUILD_PATH}")
endif()
endif()
# Locate TVM source directory
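# Resolution order: an explicit -DTVM_SOURCE_DIR, then $ENV{TVM_SOURCE_DIR},
# then the parent of TVM_PREBUILD_PATH, and finally the bundled 3rdparty/tvm submodule.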
if(NOT DEFINED TVM_SOURCE_DIR)
if(DEFINED ENV{TVM_SOURCE_DIR})
set(TVM_SOURCE_DIR "$ENV{TVM_SOURCE_DIR}")
elseif(DEFINED TVM_PREBUILD_PATH)
set(TVM_SOURCE_DIR "${TVM_PREBUILD_PATH}/..")
else()
set(TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR}/3rdparty/tvm)
endif()
@@ -127,6 +136,7 @@ message(STATUS "Collected source files: ${TILE_LANG_SRCS}")
# Add TileLang object library
add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS})
message(STATUS "TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}")
# Include directories for TileLang
set(TILE_LANG_INCLUDES
${TVM_SOURCE_DIR}/include
......
@@ -71,7 +71,7 @@ cd build
echo "Configuring TVM build with LLVM and CUDA paths..."
echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_ROCM /opt/rocm)" >> config.cmake
echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/rocm)" >> config.cmake
echo "Running CMake for TileLang..."
cmake ..
......
@@ -24,7 +24,6 @@
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/crt/error_codes.h>
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file storage_access.cc
*/
#include "storage_access.h"
#include <tvm/target/target_info.h>
#include <tvm/tir/op.h>
#include <string>
#include <utility>
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
Var buf = op->buffer->data;
StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) {
ICHECK(allow_append_) << op << " " << scope.to_string();
AccessEntry e;
e.threads = env_threads();
e.buffer = buf;
e.dtype = op->dtype.element_of();
for (const auto &index : op->indices) {
e.touched.push_back(arith::IntSet::Vector(index));
}
e.type = kRead;
e.scope = scope;
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
StmtExprVisitor::VisitExpr_(op);
}
void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
allow_append_ = true;
ICHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
Var buf = op->buffer->data;
StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) {
AccessEntry e;
e.threads = env_threads();
e.buffer = buf;
e.dtype = op->value.dtype().element_of();
for (const auto &index : op->indices) {
e.touched.push_back(arith::IntSet::Vector(index));
}
e.type = kWrite;
e.scope = scope;
curr_stmt_.access.emplace_back(std::move(e));
}
// traverse child
StmtExprVisitor::VisitStmt_(op);
// push to the scope
scope_.back().push_back(curr_stmt_);
// clear access entry.
curr_stmt_.access.clear();
allow_append_ = false;
}
void TileLangStorageAccessVisitor::VisitStmt_(const EvaluateNode *op) {
allow_append_ = true;
ICHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
StmtExprVisitor::VisitStmt_(op);
// push to the scope
if (curr_stmt_.access.size() != 0) {
scope_.back().push_back(curr_stmt_);
curr_stmt_.access.clear();
}
allow_append_ = false;
}
void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) {
allow_append_ = true;
ICHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
this->VisitExpr(op->value);
// push to the scope
scope_.back().push_back(curr_stmt_);
// clear access entry.
curr_stmt_.access.clear();
allow_append_ = false;
// traverse body block
this->VisitStmt(op->body);
}
void TileLangStorageAccessVisitor::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tvm::tir::attr::double_buffer_write) {
ICHECK(double_buffer_write_ == nullptr);
double_buffer_write_ = op->node.as<VarNode>();
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (!s.access.empty()) {
for (AccessEntry &e : s.access) {
if (e.type == kWrite && e.buffer.get() == double_buffer_write_) {
e.double_buffer_write = true;
}
}
scope_.back().emplace_back(std::move(s));
}
double_buffer_write_ = nullptr;
} else if (op->attr_key == tvm::tir::attr::coproc_scope) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
StmtExprVisitor::VisitStmt_(op);
env_threads_.pop_back();
} else if (op->attr_key == tvm::tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
env_threads_.push_back(iv);
if (!in_device_env_) {
in_device_env_ = true;
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
// no need to take the result as the thread barrier automatically syncs.
Summarize(std::move(scope_.back()), nullptr);
in_device_env_ = false;
scope_.pop_back();
} else {
StmtExprVisitor::VisitStmt_(op);
}
env_threads_.pop_back();
} else if (op->attr_key == tvm::tir::attr::hand_threaded) {
// skip this pass on blocks that were hand_threaded
// this avoids control flow and read/write conflicts
// between hand-threaded kernels and automatic threading
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void TileLangStorageAccessVisitor::VisitStmt_(const ForNode *op) {
scope_.push_back(std::vector<StmtEntry>());
StmtExprVisitor::VisitStmt_(op);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), op);
scope_.pop_back();
if (s.access.size() != 0) {
// relax the touched set to contain all ranges in the loop.
std::unordered_map<const VarNode *, arith::IntSet> relax_map;
relax_map[op->loop_var.get()] =
arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
for (AccessEntry &e : s.access) {
if (e.buffer.defined()) {
ICHECK(e.touched.size());
Array<arith::IntSet> new_touched;
for (const auto &touched : e.touched) {
new_touched.push_back(arith::EvalSet(touched, relax_map));
}
e.touched = std::move(new_touched);
}
}
}
if (!s.access.empty()) {
scope_.back().emplace_back(std::move(s));
}
}
bool IsThreadInvariant(const PrimExpr &cond) {
if (auto call = cond.as<CallNode>()) {
if (auto opt_call_op = call->op.as<Op>()) {
auto call_op = opt_call_op.value();
if (call_op.same_as(builtin::tvm_thread_invariant())) {
return true;
}
}
}
return false;
}
void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
bool is_thread_invariant = IsThreadInvariant(op->condition);
if (!is_thread_invariant) {
++condition_counter_;
}
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (op->else_case) {
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
}
scope_.back().emplace_back(std::move(s));
if (!is_thread_invariant) {
--condition_counter_;
}
}
void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
bool is_thread_invariant = IsThreadInvariant(op->condition);
if (!is_thread_invariant) {
++condition_counter_;
}
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->body);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
scope_.back().emplace_back(std::move(s));
if (!is_thread_invariant) {
--condition_counter_;
}
}
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
if (op->op.same_as(builtin::address_of())) {
ICHECK_EQ(op->args.size(), 1U);
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
Buffer buffer = load->buffer;
DataType dtype = buffer->dtype;
const VarNode *buffer_var = buffer->data.as<VarNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_);
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
e.buffer = Downcast<Var>(buffer->data);
for (const auto &index : load->indices) {
e.touched.push_back(arith::IntSet::Vector(index));
}
e.type = kRead;
e.scope = scope;
curr_stmt_.access.emplace_back(e);
}
StmtExprVisitor::VisitExpr_(load);
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
PrimExpr offset = op->args[2];
PrimExpr extent = op->args[3];
const IntImmNode *flag = op->args[4].as<IntImmNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer));
// The buffer scope.
if (Enabled(buffer, scope)) {
ICHECK(allow_append_);
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
e.buffer = Downcast<Var>(op->args[1]);
e.touched = {
arith::IntSet::FromRange(Range::FromMinExtent(offset, extent))};
e.scope = scope;
if (flag->value & 1) {
e.type = kRead;
curr_stmt_.access.emplace_back(e);
}
if (flag->value & 2) {
e.type = kWrite;
curr_stmt_.access.emplace_back(e);
}
}
StmtExprVisitor::VisitExpr_(op);
} else if (op->op.same_as(builtin::tvm_storage_sync())) {
ICHECK(allow_append_);
const std::string &s = op->args[0].as<StringImmNode>()->value;
if (s != "warp") {
StorageScope scope = StorageScope::Create(s);
AccessEntry e;
e.threads = env_threads();
e.type = kSync;
e.scope = scope;
curr_stmt_.access.emplace_back(std::move(e));
}
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
StorageScope TileLangStorageAccessVisitor::GetScope(Var buffer_var) const {
if (buffer_var->type_annotation.as<PointerTypeNode>()) {
return StorageScope::Create(GetPtrStorageScope(buffer_var));
}
return StorageScope(); // global by default
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file storage_access.h
* \brief Common data structure for storage access analysis.
*/
#ifndef TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_
#define TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_
#include <tvm/arith/int_set.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <vector>
#include "runtime/thread_storage_scope.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
* \brief Base class of storage access analysis
*/
class TileLangStorageAccessVisitor : public StmtExprVisitor {
public:
/*! \brief Storage access type */
enum AccessType {
kRead,
kWrite,
kSync,
kAlloc,
// acquired version of read, only need to handle WAR dep.
kReadAcquire
};
/*! \brief An access entry */
struct AccessEntry {
/*! \brief The thread index that access this entry */
Array<IterVar> threads;
/*! \brief The buffer variable, if any */
Var buffer = NullValue<Var>();
/*! \brief The access data type */
DataType dtype;
/*! \brief The touched access range
*
* Has one IntSet for each index in the buffer being accessed.
*/
Array<arith::IntSet> touched;
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
/*! \brief Whether the access is double buffer write */
bool double_buffer_write = false;
};
/*! \brief Access pattern about a single statement */
struct StmtEntry {
/*! \brief The statement */
const Object *stmt;
/*! \brief access patterns in the statement */
std::vector<AccessEntry> access;
};
// override visitor pattern
void VisitExpr_(const BufferLoadNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const LetStmtNode *op) final;
void VisitStmt_(const AttrStmtNode *op) override;
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const IfThenElseNode *op) final;
void VisitStmt_(const WhileNode *op) final;
void VisitExpr_(const CallNode *op) final;
protected:
TileLangStorageAccessVisitor() { scope_.push_back(std::vector<StmtEntry>()); }
/*! \return number of conditions in the current scope. */
int condition_counter() const { return condition_counter_; }
/*! \return whether we are in device environment. */
bool in_device_env() const { return in_device_env_; }
/*! \return environment threads */
const Array<IterVar> &env_threads() const { return env_threads_; }
/*!
* \brief Whether we need to analyze the buffer in the current scope.
* \param buffer The buffer to be checked.
* \param scope The scope of the buffer.
* \return Whether analysis of the buffer is enabled.
*/
virtual bool Enabled(const VarNode *buffer, const StorageScope &scope) const {
return true;
}
/*!
* \brief Summarize the sequence of operations into parent.
*
* Insert synchronization if necessary and remove unnecessary
* memory accesses that are already synced.
*
* \param seq The sequence of the access operations.
* \param loop The loop node if the sequence is a loop body, otherwise nullptr.
* \return The summarized sequence of accesses that the parent
* scope should take care of synchronizing.
*/
virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
const ForNode *loop) = 0;
/*!
* \brief Get the scope of the buffer array.
* \return The scope of the final buffer array.
*/
StorageScope GetScope(Var buffer_var) const;
// access scope
std::vector<std::vector<StmtEntry>> scope_;
private:
// whether access appending is enabled.
bool allow_append_{false};
// Whether we are in device environment
bool in_device_env_{false};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
const VarNode *double_buffer_write_{nullptr};
// the current free stmt entry.
StmtEntry curr_stmt_;
// The involving threads
Array<IterVar> env_threads_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_
@@ -15,18 +15,18 @@
#include <unordered_set>
#include "../op/builtin.h"
#include "./storage_access.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/storage_access.h"
namespace tvm {
namespace tl {
using namespace tir;
class ThreadPartialSyncPlanner : public StorageAccessVisitor {
class TileLangThreadPartialSyncPlanner : public TileLangStorageAccessVisitor {
public:
explicit ThreadPartialSyncPlanner(StorageScope sync_scope)
explicit TileLangThreadPartialSyncPlanner(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
// The syncs inserted before each statement
@@ -274,7 +274,7 @@ private:
num_partial_threads_ = NullOpt;
} else {
StorageAccessVisitor::VisitStmt_(op);
TileLangStorageAccessVisitor::VisitStmt_(op);
}
}
@@ -352,9 +352,9 @@ private:
const std::unordered_map<const Object *, int> &partial_syncs_;
};
Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
Stmt TileLangThreadPartialSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
ThreadPartialSyncPlanner planner(sync_scope);
TileLangThreadPartialSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(
@@ -365,17 +365,17 @@ using namespace tir::transform;
namespace transform {
Pass ThreadPartialSync(String storage_scope) {
Pass TileLangThreadPartialSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = tl::ThreadPartialSync(std::move(n->body), storage_scope);
n->body = tl::TileLangThreadPartialSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
.set_body_typed(ThreadPartialSync);
.set_body_typed(TileLangThreadPartialSync);
} // namespace transform
} // namespace tl
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include "./storage_access.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
public:
explicit TileLangThreadSyncPlanner(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
// The syncs inserted before each statement
std::unordered_set<const Object *> syncs_inserted_;
std::unordered_map<const Object *, int> partial_syncs_inserted_;
protected:
bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
const ForNode *loop) final {
// Redirect all "shared.dyn" buffer access to the same buffer var
// so that the accesses can be planned together.
Var shared_dyn_buf;
// for (StmtEntry& entry : seq) {
// for (AccessEntry& access : entry.access) {
// if (access.scope.rank == StorageRank::kShared && access.scope.tag ==
// ".dyn" &&
// access.buffer.defined()) {
// if (!shared_dyn_buf.defined()) {
// shared_dyn_buf = access.buffer;
// } else {
// access.buffer = shared_dyn_buf;
// }
// }
// }
// }
// Unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
// If it is a loop, the sequence is scanned a second time to account for
// loop-carried effects; a simulation-based approach is used to find dependencies.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
// If a sync is inserted, clear the now-synced reads and writes.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
// Add the read/write of current statement
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
reads.push_back(acc);
} else if (acc.type == kWrite) {
writes.push_back(acc);
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
}
}
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0)
break;
if (reads.empty() && writes.empty())
break;
bool sync_before_stmt = false;
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
break;
}
} else if (acc.type == kSync) {
reads.clear();
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
break;
}
}
}
// Return the exposed entries, removing unnecessary ones.
int sync_count = 0;
// head are before first sync, tail are after last sync
std::vector<AccessEntry> head, tail;
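// Only head and tail can still conflict with accesses in the parent scope;
// anything between two syncs is already protected.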
AccessEntry esync;
esync.threads = this->env_threads();
esync.type = kSync;
esync.scope = sync_scope_;
for (const StmtEntry &s : seq) {
if (syncs_inserted_.count(s.stmt)) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
}
for (const AccessEntry &acc : s.access) {
if (acc.type == kSync) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(esync);
}
++sync_count;
} else {
if (sync_count != 0) {
tail.push_back(acc);
} else {
head.push_back(acc);
}
}
}
}
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// clear double buffer flag after a loop is finished.
for (AccessEntry &e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
// find conflicting entry in vec.
bool FindConflict(const std::vector<AccessEntry> &prev,
const AccessEntry &curr, bool loop_carry) {
for (const AccessEntry &x : prev) {
if (FindConflict(x, curr, loop_carry)) {
return true;
}
}
return false;
}
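// Decide whether two accesses to the same buffer require a barrier: they are
// considered safe only if every touched index is the same single point and
// that point depends on the innermost thread index, or if the previous access
// is a double-buffer write being read back (outside a loop-carried check).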
bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
bool loop_carry) {
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
}
// Assumes no race between threads
// Same index value means no conflicts
// TODO(tqchen) more standard set based testing.
bool has_same_index = true;
// Even if access has the same index, those indices need to
// depend on the innermost thread id to avoid race condition
bool depends_on_thread_index = true;
const VarNode *thread_index_var = nullptr;
if (!curr.threads.empty()) {
thread_index_var = curr.threads.back()->var.get();
}
for (size_t i = 0; i < prev.touched.size(); i++) {
const auto &prev_intset = prev.touched[i];
const auto &curr_intset = curr.touched[i];
if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
PrimExpr prev_index = prev_intset.PointValue();
PrimExpr curr_index = curr_intset.PointValue();
has_same_index = ExprDeepEqual()(prev_index, curr_index);
if (thread_index_var != nullptr) {
auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
return parameter == thread_index_var;
};
depends_on_thread_index = depends_on_thread_index &&
UsesVar(curr_index, f_uses_thread_index) &&
UsesVar(prev_index, f_uses_thread_index);
}
} else {
has_same_index = false;
}
if (!(has_same_index && depends_on_thread_index)) {
break;
}
}
if (has_same_index && depends_on_thread_index) {
return false;
}
// If this is a read into a double buffer that was previously
// swapped out, then it doesn't conflict.
if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
return false;
}
// If nothing else allows sharing the same buffer, then they are
// in conflict.
return true;
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "kWarpSpecializationScope") {
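// op->node carries the two thread partitions {producer_thread_extent,
// consumer_thread_extent} attached by the warp-specialization pass; plan each
// branch separately with its own partial-sync thread count.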
IfThenElse body = Downcast<IfThenElse>(op->body);
auto partitions = Downcast<Array<IntImm>>(op->node);
ICHECK(partitions.size() == 2);
scope_.push_back(std::vector<StmtEntry>());
num_partial_threads_ = partitions[0];
this->VisitStmt(body->then_case);
StmtEntry s;
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
num_partial_threads_ = partitions[1];
scope_.push_back(std::vector<StmtEntry>());
VisitStmt(body->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
} else {
TileLangStorageAccessVisitor::VisitStmt_(op);
}
}
void insert_syncs(const Object *obj) {
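// Record that a sync must be inserted before `obj`; when inside a
// warp-specialized branch, also record how many threads the partial barrier
// must cover.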
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside
// condition";
if (syncs_inserted_.count(obj))
return;
if (num_partial_threads_.defined()) {
syncs_inserted_.insert(obj);
partial_syncs_inserted_[obj] =
static_cast<int>(num_partial_threads_.value()->value);
} else {
syncs_inserted_.insert(obj);
}
}
private:
Optional<IntImm> num_partial_threads_;
// synchronization scope
StorageScope sync_scope_;
};
// There are cases where a necessary syncthreads is not inserted by
// ThreadSyncInserter. For example, a syncthreads is needed after
// async_wait_queue in the second loop below, but since ThreadSyncInserter is
// not aware of the asynchronous semantics, it cannot tell that the syncthreads
// is needed there.
//
// // Pipeline prologue
// for i in range(125):
// async_commit_queue(0):
// async_scope:
// shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline Epilogue
// for i in range(3):
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
// This class adds syncthreads after all async_wait_queue. That includes
// syncthreads that can be inserted by ThreadSyncInserter as well, but
// ThreadSyncInserter will not insert duplicate syncthreads if it finds an
// existing one at the synchronization point.
class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator {
public:
explicit ThreadSyncAfterWaitQueueInserter(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::async_wait_queue_scope) {
auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner &&
inner->attr_key == tvm::tir::attr::async_wait_inflight_count);
auto zero = make_zero(DataType::Int(32));
auto new_body = SeqStmt({sync, inner->body});
return AttrStmt(zero, tvm::tir::attr::async_wait_queue_scope, op->value,
AttrStmt(zero, tvm::tir::attr::async_wait_inflight_count,
inner->value, new_body));
}
return StmtExprMutator::VisitStmt_(op);
}
private:
StorageScope sync_scope_;
};
class ThreadSyncInserter : public StmtExprMutator {
public:
ThreadSyncInserter(StorageScope sync_scope,
const std::unordered_set<const Object *> &syncs,
std::unordered_map<const Object *, int> partial_syncs)
: sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}
Stmt VisitStmt(const Stmt &stmt) final {
if (syncs_.size() == 0)
return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (sync_scope_.rank == StorageRank::kGlobal) {
barrier = MakeGlobalBarrier();
} else if (partial_syncs_.count(stmt.get())) {
return StmtExprMutator::VisitStmt(stmt);
} else {
barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string())}));
}
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = SeqStmt({barrier, ret});
return ret;
} else {
return StmtExprMutator::VisitStmt(stmt);
}
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer->data].read_count;
}
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
++rw_stats_[op->buffer->data].write_count;
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::thread_extent) {
bool temp = true;
std::swap(temp, in_thread_env_);
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
std::swap(temp, in_thread_env_);
// first thread scope.
if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
num_blocks_ = PrimExpr();
is_lead_ = PrimExpr();
}
return ret;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
ICHECK_EQ(op->args.size(), 5U);
Var buffer_var(Downcast<Var>(op->args[1]));
const IntImmNode *flag = op->args[4].as<IntImmNode>();
if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].read_count;
}
if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].write_count;
}
return expr;
} else if (op->op.same_as(builtin::address_of())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
ICHECK_EQ(op->args.size(), 1U)
<< "address_of should only have one argument (Buffer)";
BufferLoad load = Downcast<BufferLoad>(op->args[0]);
Var buffer_var(Downcast<Var>(load->buffer->data));
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].read_count;
}
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].write_count;
}
return expr;
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// RW statistics about data
struct Entry {
int read_count{0};
int write_count{0};
};
// Get current storage scope.
StorageScope GetScope(Var buffer_var) const {
return StorageScope::Create(GetPtrStorageScope(buffer_var));
}
// private functions.
Stmt InitGlobalBarrier(const AttrStmtNode *op) {
ICHECK(op != nullptr);
Array<PrimExpr> pargs = {
StringImm(runtime::symbol::tvm_prepare_global_barrier)};
Stmt prep =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
Stmt body = op->body;
for (const auto &kv : rw_stats_) {
const auto &e = kv.second;
if (e.read_count != 0 && e.write_count != 0) {
body = AttrStmt(kv.first, tvm::tir::attr::volatile_scope, 1, body);
}
}
rw_stats_.clear();
Stmt kinit = Evaluate(
Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}));
body = SeqStmt({kinit, body});
body = AttrStmt(op->node, op->attr_key, op->value, body);
return SeqStmt({prep, body});
}
Stmt MakeGlobalBarrier() {
ICHECK(sync_scope_.rank == StorageRank::kGlobal);
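// On first use, derive the total block count from the grid-level (rank 0)
// thread extents and build a predicate selecting the leading thread
// (rank 1 vars == 0) of each block.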
if (!num_blocks_.defined()) {
ICHECK(!is_lead_.defined());
num_work_dim_ = thread_extents_.size();
for (const AttrStmtNode *attr : thread_extents_) {
IterVar iv = Downcast<IterVar>(attr->node);
runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
if (s.rank == 0) {
num_blocks_ =
(num_blocks_.defined() ? attr->value * num_blocks_ : attr->value);
} else if (s.rank == 1) {
PrimExpr cond = iv->var == make_zero(iv->var.dtype());
is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
}
}
} else {
ICHECK_EQ(num_work_dim_, thread_extents_.size());
}
return Evaluate(
Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}));
}
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object *> &syncs_;
const std::unordered_map<const Object *, int> &partial_syncs_;
// The read write statistics of storage
std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> rw_stats_;
// The statistics for global barrier
bool in_thread_env_{false};
// memorized results
std::vector<const AttrStmtNode *> thread_extents_;
size_t num_work_dim_{0};
PrimExpr num_blocks_;
PrimExpr is_lead_;
};
Stmt TileLangThreadSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
TileLangThreadSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(std::move(stmt));
}
using namespace tir::transform;
namespace transform {
tvm::transform::Pass ThreadSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
n->body = tl::TileLangThreadSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadSync").set_body_typed(ThreadSync);
} // namespace transform
} // namespace tl
} // namespace tvm
@@ -972,7 +972,7 @@ private:
DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
producer_code, consumer_code);
// Add an attr here to handle the partial thread count in THreadSync pass.
// Add an attr here to handle the partial thread count in ThreadSync pass.
Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
Downcast<IntImm>(consumer_thread_extent)};
body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
from tvm import te
from tvm.script import tir as T
def run_passes(func: tvm.tir.PrimFunc):
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
cuda_target = tvm.target.Target("cuda", host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
"global_symbol": "test",
"target": cuda_target
}))(
mod)
mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
mod = tvm.tir.transform.SplitHostDevice()(mod)
return tilelang.transform.ThreadSync("shared")(mod)
@tvm.testing.requires_cuda
def test_thread_storage_sync():
m = te.size_var("m")
l = te.size_var("l")
A = te.placeholder((m, l), name="A")
A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1")
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2")
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, te.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
mod = run_passes(func)
f = mod["test_kernel"]
body_list = tvm.tir.stmt_list(f.body.body.body.body.body.body)
assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))
@tvm.testing.requires_cuda
def test_sync_else_branch():
def ir(A, B):
ib = tvm.tir.ir_builder.create()
Aptr = ib.buffer_ptr(A)
Bptr = ib.buffer_ptr(B)
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(tx, "thread_extent", 1)
local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local")
shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared")
with ib.for_range(0, 8) as i:
with ib.if_scope(Aptr[i] < 0):
local[i] = Aptr[i]
with ib.else_scope():
shared[i] = Aptr[i]
with ib.for_range(0, 8) as i:
with ib.if_scope(Aptr[i] < 0):
Bptr[i] = local[i]
with ib.else_scope():
Bptr[i] = shared[i]
return ib.get()
A = tvm.tir.decl_buffer((8,), "float32")
B = tvm.tir.decl_buffer((8,), "float32")
stmt = ir(A, B)
func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
mod = run_passes(func)
assert "T.tvm_storage_sync" in str(mod)
@tvm.testing.requires_cuda
def test_sync_read_thread_id_independent_location():
@T.prim_func
def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
threadIdx_x = T.env_thread("threadIdx.x")
blockIdx_x = T.env_thread("blockIdx.x")
p0 = T.Buffer([2], dtype="float32", data=p0_arg.data)
result_local = T.alloc_buffer([1], dtype="float32", scope="local")
temp_shared = T.alloc_buffer([1], dtype="float32", scope="shared")
T.launch_thread(blockIdx_x, 8)
T.launch_thread(threadIdx_x, 4)
result_local[0] = T.float32(0)
if threadIdx_x < 1:
temp_shared[0] = p0[0]
result_local[0] = result_local[0] + temp_shared[0] * p1[0]
if threadIdx_x < 1:
temp_shared[0] = p0[1]
result_local[0] = result_local[0] + temp_shared[0] * p1[1]
mod = run_passes(func)
assert "T.tvm_storage_sync" in str(mod)
@tvm.testing.requires_cuda
def test_sync_let_stmt():
@T.prim_func(private=True)
def func(A: T.Buffer((16 * 512), "float32")):
blockIdx_x = T.launch_thread("blockIdx.x", 16)
A_shared = T.allocate([512], "float32", "shared")
in_thread_A_temp = T.allocate([1], "float32", "local")
cross_thread_A_temp = T.allocate([1], "float32", "local")
threadIdx_x = T.launch_thread("threadIdx.x", 128)
A_shared_1 = T.Buffer((512,), data=A_shared, scope="shared")
for ax0 in range(512):
A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
in_thread_A_temp_1 = T.Buffer((1,), data=in_thread_A_temp, scope="local")
in_thread_A_temp_1[0] = T.float32(0)
with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) as A_temp:
in_thread_A_temp_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) as A_temp:
in_thread_A_temp_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) as A_temp:
in_thread_A_temp_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) as A_temp:
in_thread_A_temp_1[0] = A_temp
cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
T.tvm_thread_allreduce(
T.uint32(1),
in_thread_A_temp_1[0],
T.bool(True),
cross_thread_A_temp_1[0],
threadIdx_x,
)
@T.prim_func(private=True)
def expected(A: T.Buffer((8192,), "float32")):
blockIdx_x = T.launch_thread("blockIdx.x", 16)
A_shared_1 = T.allocate([512], "float32", "shared")
in_thread_A_temp_1 = T.allocate([1], "float32", "local")
cross_thread_A_temp_1 = T.allocate([1], "float32", "local")
threadIdx_x = T.launch_thread("threadIdx.x", 128)
A_shared_1_1 = T.Buffer((512,), data=A_shared_1, scope="shared")
for ax0 in range(512):
A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0]
in_thread_A_temp_1_1 = T.Buffer((1,), data=in_thread_A_temp_1, scope="local")
in_thread_A_temp_1_1[0] = T.float32(0)
T.tvm_storage_sync("shared")
with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) as A_temp:
in_thread_A_temp_1_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) as A_temp:
in_thread_A_temp_1_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) as A_temp:
in_thread_A_temp_1_1[0] = A_temp
with T.LetStmt(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) as A_temp:
in_thread_A_temp_1_1[0] = A_temp
T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
)
cross_thread_A_temp_1_1 = T.Buffer((1,), data=cross_thread_A_temp_1, scope="local")
T.tvm_thread_allreduce(
T.uint32(1),
in_thread_A_temp_1_1[0],
T.bool(True),
cross_thread_A_temp_1_1[0],
threadIdx_x,
)
mod = tvm.IRModule({"main": func})
mod = tilelang.transform.ThreadSync("shared")(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)
if __name__ == "__main__":
tilelang.testing.main()
@@ -191,8 +191,8 @@ def lower(
mod = tl.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tl.transform.ThreadSync("shared")(mod)
mod = tl.transform.ThreadSync("shared.dyn")(mod)
mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
@@ -28,7 +28,6 @@ if TVM_IMPORT_PYTHON_PATH is not None:
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
else:
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
@@ -37,14 +36,15 @@ else:
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tvm_path + "/python")
TVM_IMPORT_PYTHON_PATH = develop_tvm_path + "/python"
develop_tvm_library_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "build", "tvm")
install_tvm_library_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib")
if os.environ.get("TVM_LIBRARY_PATH") is None:
if os.path.exists(develop_tvm_library_path):
os.environ["TVM_LIBRARY_PATH"] = develop_tvm_library_path
......
@@ -1316,17 +1316,6 @@ def prefetch(
return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member
def customized_code(code: str):
"""Add a customized code block.
Parameters
----------
code : str
The code block to be added.
"""
return _ffi_api.CustomizedCode(code) # type: ignore[attr-defined] # pylint: disable=no-member
def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
......
@@ -24,9 +24,9 @@ def get_dll_directories():
source_dir = os.path.abspath(os.path.join(curr_dir, ".."))
dll_path = [
curr_dir,
os.path.join(curr_dir, "lib"), # pypi build
os.path.join(source_dir, "build"), # local build
os.path.join(source_dir, "build", "Release"),
os.path.join(curr_dir, "lib"), # pypi build
]
if TILELANG_LIBRARY_PATH:
dll_path.append(TILELANG_LIBRARY_PATH)
......
@@ -95,6 +95,22 @@ def WarpSpecializedPipeline():
return _ffi_api.WarpSpecializedPipeline() # type: ignore
def ThreadSync(storage_scope: str):
"""Insert sync between parallel read/write of shared buffers.
Parameters
----------
storage_scope: str
The target storage scope.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.ThreadSync(storage_scope) # type: ignore
def ThreadPartialSync(storage_scope: str):
"""Insert partial sync.
......
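
For illustration, a sketch of where the new Python entry points slot into a lowering pipeline, mirroring the lower() change above. The IRModule mod is a placeholder assumed to be already split into host and device functions, and the module aliases are written out explicitly:

from tvm import tir
import tilelang

mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)       # barriers for static shared memory
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)   # barriers for dynamic shared memory
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)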