Unverified commit a7c9a8b9 authored by Siyuan Feng, committed by GitHub

Refactor to support upstream tvm (#595)

**Summary of part of the rebase PR:**

1. **Support T.thread_return() → CUDA return syntax**  
   Added support for translating `T.thread_return()` to CUDA's native `return` statement.
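   A minimal sketch of the intended usage (the kernel name, buffer shape, and thread binding below are illustrative, not taken from this PR):
   ```python
   @T.prim_func
   def early_exit(A: T.Buffer((128,), "float32")):
       for tx in T.thread_binding(128, thread="threadIdx.x"):
           if tx >= 64:
               T.thread_return()  # lowered to a plain `return` in the generated CUDA kernel
           A[tx] = A[tx] + 1.0
   ```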

2. **Dynamic type support for function inputs**  
   Functions now accept dynamically typed parameters via `typing`-style annotations:
   ```python
   # dyn_type may be any scalar type, e.g. T.int32 or T.float
   dyn_type = T.int32

   @T.prim_func
   def main(
       a: dyn_type,
   ):
       ...
   ```

3. **Device Function Codegen**  
   Added support for generating `__device__` functions in CUDA:
   ```python
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def add(a: T.int32, b: T.int32) -> T.int32:
           return a + b

       @T.prim_func
       def main(
           A: T.Buffer((128, 128), "int32"),
           B: T.Buffer((128, 128), "int32"),
           C: T.Buffer((128, 128), "int32"),
       ):
           T.func_attr({"global_symbol": "main"})
           length: T.int32 = Module.add(64, 64)  # Host call
           for bx in...
parent 8edd6941
......@@ -6,7 +6,7 @@
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -209,8 +209,10 @@ tvm::transform::Pass LowerSharedBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerSharedBarrier")
.set_body_typed(LowerSharedBarrier);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
});
} // namespace transform
} // namespace tl
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Lower allreduce to device implementable ir.
* \file lower_thread_allreduce.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/update_pointer_storage_scope.h"
namespace tvm {
namespace tl {
using namespace tir;
using runtime::StorageRank;
using runtime::StorageScope;
/*!
* \brief collect the mapping from the buffer var to its allocate
*/
class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
}
public:
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// The dynamic mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *> dyn_shmem_allocs_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *>
static_shmem_allocs_;
};
class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
explicit ThreadAllreduceBuilder(const TargetNode *target,
bool is_dynamic = false)
: target_(target),
warp_size_(
target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()),
max_num_threads_(target->GetAttr<Integer>("max_num_threads", -1)
.value()
.IntValue()) {
if (is_dynamic) {
shared_scope = "shared.dyn";
}
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
thread_extents_.push_back(op);
Stmt ret = StmtExprMutator::VisitStmt_(op);
thread_extents_.pop_back();
return ret;
} else if (op->attr_key == tir::attr::reduce_scope) {
const CommReducerNode *combiner = op->node.as<CommReducerNode>();
ICHECK(combiner);
reduce_combiner_.push_back(combiner);
Stmt ret = StmtExprMutator::VisitStmt_(op);
reduce_combiner_.pop_back();
return ret;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const EvaluateNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<EvaluateNode>();
const CallNode *call = op->value.as<CallNode>();
if (call && call->op.same_as(builtin::tvm_thread_allreduce())) {
return MakeAllreduce(call);
} else {
return stmt;
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
if (auto it = alloc_remap_.find(node->buffer_var.get());
it != alloc_remap_.end()) {
Buffer buf = Downcast<Buffer>(it->second);
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = buf->data;
write_ptr->dtype = buf->dtype;
write_ptr->extents = buf->shape;
write_ptr->condition = const_true(buf->dtype.lanes());
if (buf.scope() == shared_scope) {
// Use volatile access to shared buffer.
write_ptr->body =
AttrStmt(buf->data, tir::attr::volatile_scope, 1, write_ptr->body);
}
}
return std::move(node);
}
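// Lazily build (and cache) a remapped Buffer when only its backing data
// Var has been remapped so far.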
Optional<Buffer> GetRemappedBuffer(const Buffer &buf) {
if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) {
return it->second;
}
if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) {
Buffer new_buf = buf;
new_buf.CopyOnWrite()->data = it->second;
buf_remap_[buf.get()] = new_buf;
return new_buf;
}
return std::nullopt;
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
if (auto buf = GetRemappedBuffer(node->buffer)) {
node.CopyOnWrite()->buffer = buf.value();
}
return std::move(node);
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
if (auto it = load_remap_.find(op->buffer->data.get());
it != load_remap_.end()) {
for (const auto &index : op->indices) {
ICHECK(is_zero(index))
<< "The index of buffer " << op->buffer << " is " << index;
}
return it->second;
}
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();
if (auto opt = GetRemappedBuffer(load->buffer)) {
load.CopyOnWrite()->buffer = opt.value();
}
return std::move(load);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (auto opt = GetRemappedBuffer(store->buffer)) {
store.CopyOnWrite()->buffer = opt.value();
}
return std::move(store);
}
private:
// Thread entry
struct ThreadEntry {
runtime::ThreadScope scope;
IterVar iv;
int extent;
// comparator
bool operator<(const ThreadEntry &other) const {
return scope.dim_index < other.scope.dim_index;
}
};
// make allreduce.
Stmt MakeAllreduce(const CallNode *call) {
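// Argument layout of tvm_thread_allreduce (as consumed below):
//   args[0]                        : number of reduction values (size)
//   args[1 .. size]                : the values to reduce
//   args[size + 1]                 : the reduction predicate
//   args[size + 2 .. 2 * size + 1] : loads of the destination buffers
//   args[2 * size + 2 .. ]         : the reduction thread (IterVar) vars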
ICHECK(!reduce_combiner_.empty());
const CommReducerNode *combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
ICHECK(size_of_args) << call->args[0]->GetTypeKey();
ICHECK_EQ(size, size_of_args->value);
Array<PrimExpr> inits = combiner->identity_element;
std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
PrimExpr cond = call->args[size + 1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1 + idx];
if (!is_one(cond)) {
values[idx] = Select(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].dtype();
}
std::vector<Buffer> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
PrimExpr arg = call->args[2 + size + idx];
// Loads from boolean buffers may have cast nodes inserted by
// earlier passes.
if (auto cast = arg.as<CastNode>()) {
arg = cast->value;
}
buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
}
std::unordered_set<const VarNode *> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const VarNode *v = call->args[i].as<VarNode>();
// A simple optimization replaces an iteration variable with a constant
// when the extent of the iteration is 1. Since a threaded IterVar always
// starts from 0, we can just ignore the variable in that case.
if (v) {
reduce_set.insert(v);
} else {
ICHECK(call->args[i].as<IntImmNode>() &&
call->args[i].as<IntImmNode>()->value == 0)
<< "arg" << i << "should be a VarNode or IntImmNode "
<< "while it is " << call->args[i];
}
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
int reduce_dim_index = -1;
for (const AttrStmtNode *attr : thread_extents_) {
ThreadEntry e;
IterVar iv = Downcast<IterVar>(attr->node);
e.scope = runtime::ThreadScope::Create(iv->thread_tag);
e.iv = iv;
ICHECK_LE(e.scope.rank, 1);
ICHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
const auto *ptr = attr->value.as<IntImmNode>();
ICHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
// skip thread dims whose extent is 1
if (e.extent == 1) {
continue;
}
if (reduce_set.count(iv->var.get())) {
bool already_exists = false;
for (const auto &entry : vred) {
if (entry.scope.dim_index == e.scope.dim_index) {
already_exists = true;
break;
}
}
if (!already_exists) {
vred.push_back(e);
++nmatch;
reduce_dim_index = e.scope.dim_index;
}
} else {
bool already_exists = false;
for (const auto &entry : vpar) {
if (entry.scope.dim_index == e.scope.dim_index) {
already_exists = true;
break;
}
}
if (!already_exists) {
vpar.push_back(e);
}
}
}
}
// remove reduce thread from parallel thread
if (reduce_dim_index != -1) {
for (size_t i = 0; i < vpar.size(); ++i) {
if (vpar[i].scope.dim_index == reduce_dim_index) {
vpar.erase(vpar.begin() + i);
break;
}
}
}
ICHECK_EQ(nmatch, reduce_set.size())
<< "Not all reduce index are presented in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// the size of each index.
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);
// the longest contiguous reduce extent after flattening
int contiguous_reduce_extent = 1;
std::vector<std::tuple<int, int, bool>>
block_threads; // tuple(dim_index, extent, is_reduce)
for (const ThreadEntry &thr : vred) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
}
}
for (const ThreadEntry &thr : vpar) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
}
}
// sort according to dim_index
std::sort(block_threads.begin(), block_threads.end());
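// Walk the threadIdx dims in ascending dim order: multiply the extents of
// the leading reduce dims and stop at the first non-reduce dim, since
// reduce threads beyond it are no longer contiguous in the flattened index.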
for (auto &&thr_attr : block_threads) {
auto [dim_index, extent, is_reduce] = thr_attr;
(void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
if (is_reduce) {
contiguous_reduce_extent *= extent;
} else {
break;
}
}
std::vector<Stmt> seq;
std::vector<Buffer> new_alloc_bufs;
//
// This is an optimization. For small reduction sizes, it may be beneficial
// for a single warp to perform the entire reduction. No trips to shared
// memory and no cross-warp synchronizations are required.
// The following code emits the reduction as follows:
//
// Allocate reduction vars v[i], i = 0..size-1
//
// for offset from WARP_SIZE to 1 by 2
//
// a <- load(v[i])
// b <- shuffle_down(load(v[i], offset))
// v[i] <- reduction(a, b)
//
// broadcast results from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
// When the thread extent is multiple of warp size, we can use a two-stage
// warp-level reduction to optimize. This is implemented by applying the
// algorithm above twice.
//
// For example, suppose we want to use 512 threads to reduce 512 elements
// and the warp size is 32. In this case there are (512 / 32) = 16 warps.
// In the first stage, each of the 16 warps reduces 32 elements. So after
// the stage, we have 16 remaining elements to be reduced, one for each
// warp. We store the 16 elements in shared memory, and start the second
// stage. In the second stage we use the first 16 lanes of the first warp to
// reduce the remaining elements, and this reduction can also be optimized
// by shuffle_down warp-level primitives.
PrimExpr zero_index = make_const(reduce_index->dtype, 0);
if (IsWarpReduction(types, group_extent, reduce_extent,
contiguous_reduce_extent)) {
std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (reduce_extent <= warp_size_) {
std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, reduce_extent, group_index,
mask, std::nullopt, &seq);
// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids emitting predicated stores, as all threads uniformly
// write the same result.
for (size_t i = 0; i < size; ++i) {
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
PrimExpr val = BufferLoad(buf, {zero_index});
ICHECK_EQ(val->dtype, types[i]);
PrimExpr splat =
WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(),
val, reduce_extent * group_index);
seq.push_back(BufferStore(buf, splat, {zero_index}));
}
} else {
int n_warps = reduce_extent / warp_size_;
std::vector<Buffer> local_bufs;
// 1. Create the staging buffer in shared memory.
std::vector<Buffer> staging_shared_bufs;
staging_shared_bufs.reserve(size);
for (size_t i = 0; i < size; ++i) {
Buffer staging_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype,
n_warps * group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging",
/*storage_scope=*/shared_scope);
staging_shared_bufs.push_back(staging_shared_buf);
new_alloc_bufs.push_back(staging_shared_buf);
}
// 2. First round of allreduce.
std::tie(reduce_results, local_bufs) =
MakeWarpAllreduce(values, types, combiner, reduce_index, warp_size_,
group_index, mask, std::nullopt, &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 3. Write allreduce results to staging buffer.
std::vector<Stmt> write_staging_buf;
write_staging_buf.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(
Downcast<BufferLoad>(reduce_results[i])->buffer);
write_staging_buf.push_back(BufferStore(
/*buffer=*/staging_shared_bufs[i],
/*value=*/reduce_results[i],
/*indices=*/
{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
}
PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
seq.push_back(SyncThread(shared_scope));
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] =
BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index <
make_const(reduce_index->dtype, n_warps),
&seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
local_bufs.end());
// 5. Create shared memory buffer(s) of `group_extent` elements, storing
// the allreduce results so that every thread can access them.
std::vector<Stmt> write_result;
write_result.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(
Downcast<BufferLoad>(reduce_results[i])->buffer);
Buffer broadcast_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result",
/*storage_scope=*/shared_scope);
write_result.push_back(BufferStore(broadcast_shared_buf,
reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the
// shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index,
SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread(shared_scope));
}
// Write back allreduce results and update existing allocations.
for (size_t i = 0; i < size; ++i) {
ICHECK(!load_remap_.count(buffers[i]->data.get()));
PrimExpr pred = const_true(types[i].lanes());
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
ICHECK_EQ(reduce_results[i]->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = reduce_results[i];
auto node =
Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
} else {
std::vector<Buffer> shared_bufs(size);
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores;
for (size_t i = 0; i < size; ++i) {
stores.push_back(BufferStore(buffers[i], values[i], {0}));
}
return SeqStmt::Flatten(stores);
}
// This sync is necessary because there might be incomplete read of
// previous iteration on the same buffer.
seq.emplace_back(SyncThread(shared_scope));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer(
{IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), shared_scope);
seq.emplace_back(
BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
seq.emplace_back(SyncThread(shared_scope));
seq.emplace_back(MakeBufAllreduce(
combiner, types, shared_bufs, reduce_index, group_index,
reduce_extent, group_extent, contiguous_reduce_extent));
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]->data.get()));
PrimExpr pred = const_true(types[idx].lanes());
BufferLoad load(shared_bufs[idx],
{BufIndex(make_zero(reduce_index.dtype()), group_index,
reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}
// Fix all local allocations as all statements are built.
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : new_alloc_bufs) {
body = DeclBuffer(buf, body);
body = Allocate(buf->data, buf->dtype, buf->shape,
const_true(buf->dtype.lanes()), body);
}
return body;
}
std::pair<std::vector<PrimExpr>, std::vector<Buffer>>
MakeWarpAllreduce(std::vector<PrimExpr> src_values, //
std::vector<DataType> dtypes, //
const CommReducerNode *combiner, //
PrimExpr reduce_index, int reduce_extent, //
PrimExpr group_index, //
PrimExpr mask, Optional<PrimExpr> predicate, //
std::vector<Stmt> *seq) {
int n_buffers = src_values.size();
std::vector<Buffer> shared_bufs;
std::vector<Buffer> local_bufs;
shared_bufs.reserve(n_buffers);
// This is the index into the reduction variable, one reduction
// variable per warp. Local scope is easier to reason about without
// relying on a pattern-match pass to fix it up later.
Array<PrimExpr> zero_indices = {0};
Array<PrimExpr> shape = {1};
std::vector<Stmt> load_values;
load_values.reserve(n_buffers);
for (int idx = 0; idx < n_buffers; ++idx) {
shared_bufs.push_back(decl_buffer(
shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
load_values.push_back(
BufferStore(shared_bufs[idx], src_values[idx], zero_indices));
// Uses a local variable to store the shuffled data. Later
// on, an allocation will be built for this local variable.
local_bufs.push_back(
decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
}
if (predicate.defined()) {
seq->push_back(
IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
} else {
seq->insert(seq->end(), load_values.begin(), load_values.end());
}
// The mask for this reducer, as it may sit inside divergent control
// flow. A local variable caches the currently active lanes.
Optional<Buffer> mask_buffer;
if (need_warp_shuffle_mask_) {
mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
// Push the buffer description. Later this will have an
// allocation built for it.
local_bufs.push_back(mask_buffer.value());
}
// Emit reductions within a warp.
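// Start from the largest power-of-two offset below reduce_extent and halve
// the shuffle-down offset on each iteration.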
int start_offset = 1;
while (start_offset * 2 < reduce_extent) {
start_offset *= 2;
}
for (int offset = start_offset; offset > 0; offset /= 2) {
// Load reduction values, no synchronization needed.
Array<PrimExpr> a, b;
for (int i = 0; i < n_buffers; ++i) {
Buffer shared_buf = shared_bufs[i];
BufferLoad val(shared_buf, zero_indices);
ICHECK_EQ(val->dtype, dtypes[i]);
a.push_back(val);
// __shfl_*sync calls shall not appear in if_then_else expressions
// as this causes extra divergence. E.g.
//
// v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
//
// behaves differently from
//
// int t = __shfl_sync(mask, v1, 0);
// v1 = (v2 < v3) ? v3 : t;
//
// The former may cause a deadlock, as there is a divergent
// branch with a warp-sync call inside.
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(),
mask_buffer, val, offset);
Buffer local_buf = local_bufs[i];
Stmt s = BufferStore(local_buf, other, zero_indices);
seq->push_back(s);
BufferLoad load = BufferLoad(local_buf, zero_indices);
ICHECK_EQ(load->dtype, dtypes[i]);
b.push_back(load);
}
// Do reductions.
Array<PrimExpr> ret = (*combiner)(a, b);
// Store the reduction result to itself.
std::vector<Stmt> stores;
stores.reserve(n_buffers);
for (int i = 0; i < n_buffers; ++i) {
Buffer buf = shared_bufs[i];
stores.push_back(BufferStore(buf, ret[i], zero_indices));
}
// During the sub-warp reduction, values from inactive threads could be
// read, which is undefined behavior according to the CUDA documentation.
//
// In practice, the returned values are usually 0, which does no harm to a
// sum reduction. However, the result can be incorrect for max or prod
// reductions. Therefore an additional range check has to be performed to
// ensure correctness.
if (offset * 2 > reduce_extent) {
PrimExpr cond = reduce_index + offset < reduce_extent;
seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
} else {
seq->push_back(SeqStmt::Flatten(stores));
}
}
std::vector<PrimExpr> reduce_results;
reduce_results.reserve(n_buffers);
for (int i = 0; i < n_buffers; ++i) {
reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
}
return {reduce_results, local_bufs};
}
// make allreduce.
Stmt MakeBufAllreduce(const CommReducerNode *combiner,
const std::vector<DataType> &types,
const Array<Buffer> &shared_bufs, PrimExpr reduce_index,
PrimExpr group_index, int reduce_extent,
int group_extent, int contiguous_reduce_extent) {
// Get next power of two
int reduce_align = 1;
while (reduce_extent > reduce_align) {
reduce_align = reduce_align << 1;
}
ICHECK_GT(reduce_align, 1);
std::vector<Stmt> seq;
size_t size = shared_bufs.size();
PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
// make reduction
auto fload = [&](int offset) {
Array<PrimExpr> a, b;
for (size_t i = 0; i < size; ++i) {
BufferLoad b_load(
shared_bufs[i],
{BufIndex(reduce_index + offset, group_index, reduce_extent)});
ICHECK_EQ(b_load->dtype, types[i]);
b.push_back(b_load);
BufferLoad a_load(shared_bufs[i], {buf_index});
ICHECK_EQ(a_load->dtype, types[i]);
a.push_back(a_load);
}
Array<PrimExpr> ret = (*combiner)(a, b);
return ret;
};
auto fstore = [&](const Array<PrimExpr> &ret) {
std::vector<Stmt> stores(size);
for (size_t i = 0; i < size; ++i) {
stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
}
return SeqStmt::Flatten(stores);
};
auto freduce = [&](int offset) {
auto ret = fload(offset);
return fstore(ret);
};
// Step one: if reduce_extent is not a power of two, perform the first
// reduction step under a boundary condition.
if (reduce_align > reduce_extent) {
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread(shared_scope));
}
// normal synchronization
bool warp_align =
group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
while (reduce_align > contiguous_reduce_extent ||
reduce_align > warp_size_ || !warp_align) {
if (reduce_align == 1) {
break;
}
reduce_align = reduce_align >> 1;
PrimExpr cond = reduce_index < reduce_align;
seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
seq.emplace_back(SyncThread(shared_scope));
}
// in warp synchronization.
if (reduce_align > 1) {
PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);
std::vector<Stmt> in_warp_seq;
while (reduce_align > 1) {
reduce_align = reduce_align >> 1;
// freduce can read/write to the same memory location. For
// example, with reduce_align of 4, threadIdx 3 reads from
// memory location 7 as threadIdx 7 is writing to it.
// Therefore, we need to separate out the load from the store
// with a memory barrier in-between. This isn't necessary for
// the earlier normal synchronization, because those are each
// protected by an if-statement. The if-statement is avoided
// here to reduce thread divergence.
auto loads = fload(reduce_align);
Array<Var> in_warp_local_vars;
for (auto expr : loads) {
Var var("w_" + std::to_string(reduce_align) + "_" +
std::to_string(in_warp_local_vars.size()),
expr->dtype);
in_warp_local_vars.push_back(var);
}
std::vector<Stmt> in_let_statement;
in_let_statement.emplace_back(SyncThread("warp"));
in_let_statement.emplace_back(
fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
in_let_statement.emplace_back(SyncThread("warp"));
Stmt body = SeqStmt::Flatten(in_let_statement);
for (size_t i = 0; i < size; i++) {
body = LetStmt(in_warp_local_vars[i], loads[i], body);
}
in_warp_seq.push_back(body);
}
Stmt warp_body = SeqStmt::Flatten(in_warp_seq);
seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
seq.emplace_back(SyncThread(shared_scope));
}
return SeqStmt::Flatten(seq);
}
// Flatten the thread index into a single linear index.
// Also return the total extent via out_total_extent.
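// Example: with threadIdx.x (extent Ex) and threadIdx.y (extent Ey) this
// returns x + y * Ex and sets *out_total_extent to Ex * Ey.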
PrimExpr FlattenThread(const std::vector<ThreadEntry> &tvec,
int *out_total_extent) {
int &total_extent = *out_total_extent;
total_extent = 1;
if (tvec.size() == 0) {
return make_zero(DataType::Int(32));
}
PrimExpr ret;
for (const ThreadEntry &e : tvec) {
if (ret.defined()) {
ret = ret + e.iv->var * total_extent;
} else {
ICHECK_EQ(total_extent, 1);
ret = e.iv->var;
}
total_extent *= e.extent;
}
return ret;
}
// The local buffer index.
PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index,
int reduce_extent) {
if (!is_zero(group_index)) {
return analyzer_.Simplify(group_index * reduce_extent + reduce_index);
} else {
return reduce_index;
}
}
// sync thread op.
static Stmt SyncThread(const std::string &sync) {
return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
{StringImm(sync)}));
}
// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op &op, Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
mask = BufferLoad(mask_buffer.value(), indices);
} else {
mask = IntImm(DataType::Int(32), 0);
}
PrimExpr width = IntImm(DataType::Int(32), warp_size_);
Array<PrimExpr> args{mask, val, delta_or_lane, width, width};
return Call(val.dtype(), op, args);
}
// Check if we can use warp level reduction.
//
// Note: The ROCm backend will only have warp reductions for now.
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
bool IsWarpReduction(const std::vector<DataType> &types, int group_extent,
int reduce_extent, int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
return false;
}
need_warp_shuffle_mask_ = target_->kind->name != "metal";
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_fixed_length_vector())
return ty.bits() * ty.lanes() != 32;
return ty.bits() != 32;
}))) {
return false;
}
// Supported types:
// {u}int, {u}long, {u}long long, float, double, half/half2
if (std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_float16())
return ty.lanes() > 2;
if (ty.is_fixed_length_vector())
return true;
return ty.bytes() < 4 || ty.bytes() > 8;
})) {
return false;
}
if (thread_extents_.empty()) {
return false;
}
// reduce region must be contiguous.
if (contiguous_reduce_extent != reduce_extent) {
return false;
}
// whether reduce_extent and group_extent are valid for warp reduction.
if (target_->kind->name == "rocm") {
return reduce_extent == warp_size_;
} else {
if (reduce_extent == 1) {
return false; // no need to warp reduce
} else {
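// Sub-warp: the reduce extent evenly divides the warp size.
// Multi-warp: the reduce extent spans whole warps and the block has at
// most warp_size * warp_size threads, so the second stage fits in one warp.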
bool is_subwarp_reduction = warp_size_ % reduce_extent == 0;
bool is_multiwarp_reduction =
max_num_threads_ != -1 &&
max_num_threads_ <= warp_size_ * warp_size_ &&
reduce_extent % warp_size_ == 0;
if (is_subwarp_reduction || is_multiwarp_reduction) {
return true;
} else {
return group_extent == 1 && reduce_extent <= warp_size_;
}
}
}
}
// The target.
const TargetNode *target_ = nullptr;
// The shared scope.
String shared_scope = "shared";
// The warp size of the device.
int warp_size_{1};
// The maximum number of threads of the device. "-1" denotes unknown.
int max_num_threads_{-1};
// A boolean indicating if the target supports warp-level masking.
bool need_warp_shuffle_mask_;
// surrounding scope of thread extent.
std::vector<const AttrStmtNode *> thread_extents_;
std::vector<const CommReducerNode *> reduce_combiner_;
// The load remap
std::unordered_map<const VarNode *, PrimExpr> load_remap_;
// Allocate remap
std::unordered_map<const VarNode *, Buffer> alloc_remap_;
// BufferVar remap
std::unordered_map<const VarNode *, Var> var_remap_;
// Buffer remap
std::unordered_map<const BufferNode *, Buffer> buf_remap_;
// Internal analyzer
arith::Analyzer analyzer_;
};
namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
AllocateCollector collector;
collector(f->body);
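// If more than one dynamic shared-memory allocation is present, place the
// reduction staging buffers in the "shared.dyn" scope as well.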
bool is_dynamic = collector.dyn_shmem_allocs_.size() > 1;
auto *n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
const TargetNode *target_node = target.as<TargetNode>();
ThreadAllreduceBuilder thread_all_reduce(target_node, is_dynamic);
n->body = thread_all_reduce(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
});
} // namespace transform
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief Lower the tile op for further codegen.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -108,12 +109,14 @@ private:
* \return The rewritten block.
*/
Stmt RewritePaddingMap(const BlockNode *op) {
auto padding_map =
op->annotations.Get(attr::kPaddingMap).as<Map<Var, PrimExpr>>().value();
auto padding_map = op->annotations.Get(attr::kPaddingMap);
if (!padding_map) {
LOG(FATAL) << "Padding map annotation is missing";
}
Map<Var, Var> var_remap = CreateVarRemap();
Map<Var, PrimExpr> new_padding_map =
RemapPaddingMap(padding_map, var_remap);
Map<Var, PrimExpr> new_padding_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(padding_map.value()), var_remap);
auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
......@@ -235,7 +238,7 @@ private:
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
Optional<PrimExpr> offset = NullOpt,
Optional<PrimExpr> offset = std::nullopt,
DataType dtype = DataType::Int(32)) {
// The 2nd arg of the T.tvm_access_ptr call is the offset; we set it to 0
// and accumulate it into smem_offset
......@@ -318,7 +321,7 @@ private:
op->op.same_as(tl::tma_store()))) {
has_tma_ = true;
}
Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
Array<RelaxExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
builtin::mma_store()};
if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
......@@ -354,7 +357,7 @@ private:
// mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
} else {
......@@ -496,7 +499,10 @@ tvm::transform::Pass LowerTileOp() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
});
} // namespace transform
} // namespace tl
......
......@@ -20,8 +20,10 @@
/*!
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
......@@ -30,7 +32,6 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -75,7 +76,7 @@ public:
private:
struct ConvertedInfo {
int tcode{-1};
int type_index{-1};
PrimExpr expr;
Buffer dummy_val_buffer;
Buffer dummy_tcode_buffer;
......@@ -87,13 +88,13 @@ private:
// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
info.tcode = kTVMArgInt;
info.type_index = ffi::TypeIndex::kTVMFFIInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
info.tcode = kTVMArgFloat;
info.type_index = ffi::TypeIndex::kTVMFFIFloat;
info.expr = Cast(DataType::Float(64), val);
} else if (dtype.is_void()) {
info.tcode = kTVMNullptr;
info.type_index = ffi::TypeIndex::kTVMFFINone;
info.expr = val;
} else {
LOG(FATAL) << "data type " << dtype << " not supported yet";
......@@ -101,18 +102,18 @@ private:
// If multiple return locations have the same data type, use the
// same dummy buffer declaration.
auto it = dummy_val_buffer_map_.find(info.tcode);
auto it = dummy_val_buffer_map_.find(info.type_index);
if (it != dummy_val_buffer_map_.end()) {
info.dummy_val_buffer = it->second;
} else {
info.dummy_val_buffer =
Buffer(ret_var_, info.expr.dtype(), {1}, {1}, ConstInt32(0),
ret_var_->name_hint, 0, 0, kDefault);
dummy_val_buffer_map_[info.tcode] = info.dummy_val_buffer;
dummy_val_buffer_map_[info.type_index] = info.dummy_val_buffer;
}
// The tcode is always a 32-bit int, so we don't need to have a separate
// map.
// The type_index is always a 32-bit int, so we don't need to have a
// separate map.
if (!dummy_tcode_buffer_.defined()) {
dummy_tcode_buffer_ =
Buffer(ret_tcode_, DataType::Int(32), {1}, {1}, ConstInt32(0),
......@@ -126,7 +127,8 @@ private:
Stmt WriteToOut(PrimExpr val) {
auto info = ConvertForFFI(val);
Stmt store_val = BufferStore(info.dummy_val_buffer, info.expr, {0});
Stmt store_tcode = BufferStore(info.dummy_tcode_buffer, info.tcode, {0});
Stmt store_tcode =
BufferStore(info.dummy_tcode_buffer, info.type_index, {0});
Stmt ret_zero = Evaluate(tvm::ret(0));
return SeqStmt({store_val, store_tcode, ret_zero});
}
......@@ -153,7 +155,7 @@ public:
if (rewriter.made_change_) {
return stmt;
} else {
return NullOpt;
return std::nullopt;
}
}
......@@ -204,21 +206,21 @@ inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
* \param func The function to be inspected
*
* \returns The global_symbol to be used for the function at call
* sites, or NullOpt if the function is to remain unchanged.
* sites, or std::nullopt if the function is to remain unchanged.
*/
Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
return NullOpt;
return std::nullopt;
}
}
// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
return NullOpt;
return std::nullopt;
}
return global_symbol;
......@@ -344,9 +346,9 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
// type code checks
Var tcode(param->name_hint + ".code", DataType::Int(32));
Var type_index(param->name_hint + ".code", DataType::Int(32));
seq_init.emplace_back(LetStmt(
tcode,
type_index,
BufferLoad(buf_packed_arg_type_ids, {IntImm(DataType::Int(32), i)}),
nop));
DataType t = param.dtype();
......@@ -354,20 +356,22 @@ PrimFunc MakePackedAPI(PrimFunc func) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_init.emplace_back(
AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMNullptr,
AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
tvm::tir::StringImm(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(
AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));
seq_init.emplace_back(AssertStmt(type_index == kDLInt,
tvm::tir::StringImm(msg.str()), nop));
} else {
ICHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_init.emplace_back(
AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop));
seq_init.emplace_back(AssertStmt(type_index == kDLFloat,
tvm::tir::StringImm(msg.str()), nop));
}
}
......@@ -406,13 +410,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
bool need_set_device =
(target_device_type != kDLMicroDev &&
(
// or is c source target
target_device_type != kDLCPU || target->kind->name != "llvm"));
if (need_set_device) {
if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
Stmt set_device =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
{StringImm(runtime::symbol::tvm_set_device),
......@@ -468,7 +466,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
<< " are used, but are not passed in as API arguments";
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32)); // return the function.
return func;
}
......@@ -516,8 +513,10 @@ tvm::transform::Pass MakePackedAPI() {
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MakePackedAPI").set_body_typed([]() {
return MakePackedAPI();
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MakePackedAPI",
[]() { return MakePackedAPI(); });
});
} // namespace tl
......
......@@ -3,6 +3,7 @@
* \brief Merge the If Stmt in SeqStmt
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -91,7 +92,10 @@ tvm::transform::Pass MergeIfStmt() {
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MergeIfStmt").set_body_typed(MergeIfStmt);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
});
} // namespace tl
} // namespace tvm
......@@ -23,8 +23,9 @@
* memory allocation. This pass merges multiple TIR-level dynamic or static
* shared memory allocations into one allocation.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
......@@ -1048,8 +1049,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
{});
}
TVM_REGISTER_GLOBAL("tl.transform.MergeSharedMemoryAllocations")
.set_body_typed(MergeSharedMemoryAllocations);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations);
});
} // namespace transform
} // namespace tl
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -220,14 +202,14 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
loop_stack_.emplace_back(op->loop_var, op->extent);
auto num_stages_anno = op->annotations.Get("num_stages");
if (!num_stages_anno.defined()) {
if (!num_stages_anno) {
auto for_node = StmtExprMutator::VisitStmt_(op);
loop_stack_.pop_back();
return for_node;
}
ICHECK(num_stages_anno.as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
ICHECK(num_stages_anno->as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
......@@ -340,8 +322,10 @@ tvm::transform::Pass MultiVersionBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MultiVersionBuffer")
.set_body_typed(MultiVersionBuffer);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
} // namespace tl
} // namespace tvm
......@@ -3,6 +3,7 @@
* \brief Lower L2 persistent annotation
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -59,8 +60,10 @@ tvm::transform::Pass PersistThreadblock() {
return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PersistThreadblock")
.set_body_typed(PersistThreadblock);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file pipeline_planning.cc
* \brief Plan the software pipeline
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -224,12 +201,12 @@ private:
auto order_anno = loop->annotations.Get("tl_pipeline_order");
auto stage_anno = loop->annotations.Get("tl_pipeline_stage");
auto num_stages_anno = loop->annotations.Get("num_stages");
if (order_anno.defined() && stage_anno.defined()) {
if (order_anno && stage_anno) {
// Check if order_anno or stage_anno contains -1, which means TMA+WS is
// enabled
bool ws_tma_enabled = false;
auto order_array = Downcast<Array<Integer>>(order_anno);
auto stage_array = Downcast<Array<Integer>>(stage_anno);
auto order_array = Downcast<Array<Integer>>(order_anno.value());
auto stage_array = Downcast<Array<Integer>>(stage_anno.value());
for (const auto &val : order_array) {
if (val->value == -1) {
ws_tma_enabled = true;
......@@ -249,20 +226,20 @@ private:
return StmtExprMutator::VisitStmt_(loop);
}
Map<String, ObjectRef> annotations;
Map<String, Any> annotations;
for (const auto &[key, value] : loop->annotations) {
if (key != "tl_pipeline_order") {
annotations.Set(key, value);
}
}
annotations.Set(tir::attr::software_pipeline_order, order_anno);
annotations.Set(tir::attr::software_pipeline_order, order_anno.value());
for (const auto &[key, value] : loop->annotations) {
if (key != "tl_pipeline_stage") {
annotations.Set(key, value);
}
}
annotations.Set(tir::attr::software_pipeline_stage, stage_anno);
annotations.Set(tir::attr::software_pipeline_stage, stage_anno.value());
if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});
......@@ -271,9 +248,9 @@ private:
return for_node;
}
if (!num_stages_anno.defined())
if (!num_stages_anno)
return StmtExprMutator::VisitStmt_(loop);
int num_stages = num_stages_anno.as<IntImmNode>()->value;
int num_stages = num_stages_anno->as<IntImmNode>()->value;
Stmt pipeline_body{nullptr};
if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......@@ -443,7 +420,7 @@ private:
}
// Finally, make the pipeline annotation
Map<String, ObjectRef> annotations;
Map<String, Any> annotations;
for (const auto &[key, value] : loop->annotations) {
if (key != "num_stages") {
annotations.Set(key, value);
......@@ -496,8 +473,10 @@ tvm::transform::Pass PipelinePlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
.set_body_typed(PipelinePlanning);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
} // namespace tl
} // namespace tvm
/*!
* \file simplify.cc
* \brief Remove useless parameters of TL PrimFunc.
* \brief Statement simplifier based on analyzer and remove useless parameters
* of TL PrimFunc.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
......@@ -19,39 +21,45 @@ namespace tl {
using namespace tir;
using namespace arith;
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
bool transitively_prove_inequalities;
bool propagate_knowns_to_prove_conditional;
bool propagate_knowns_to_simplify_expressions;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;
TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
TVM_ATTR_FIELD(transitively_prove_inequalities)
.describe("If true, simplify conditionals with transitive combinations "
"of scoped constraints")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
.describe("If true, known buffer values are propagated and used to "
"statically prove conditionals")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
.describe("If true, known buffer values are propagated and used to "
"replace BufferLoad wherever "
"possible")
.set_default(false);
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
.describe("If true, simplify conditionals into an AND of ORs")
.set_default(false);
TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
.describe("If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.set_default(false);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SimplifyConfigNode>()
.def_ro("transitively_prove_inequalities",
&SimplifyConfigNode::transitively_prove_inequalities,
"If true, simplify conditionals with transitive combinations "
"of scoped constraints",
refl::DefaultValue(false))
.def_ro("propagate_knowns_to_prove_conditional",
&SimplifyConfigNode::propagate_knowns_to_prove_conditional,
"If true, known buffer values are propagated and used to "
"statically prove conditionals",
refl::DefaultValue(false))
.def_ro("propagate_knowns_to_simplify_expressions",
&SimplifyConfigNode::propagate_knowns_to_simplify_expressions,
"If true, known buffer values are propagated and used to "
"replace BufferLoad wherever "
"possible",
refl::DefaultValue(false))
.def_ro("convert_boolean_to_and_of_ors",
&SimplifyConfigNode::convert_boolean_to_and_of_ors,
"If true, simplify conditionals into an AND of ORs",
refl::DefaultValue(false))
.def_ro("apply_constraints_to_boolean_branches",
&SimplifyConfigNode::apply_constraints_to_boolean_branches,
"If true, simplify each branch of AND/OR under a constraints "
"provided by the other "
"branch",
refl::DefaultValue(false));
}
static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
......@@ -200,6 +208,7 @@ public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
SimplifyConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
......@@ -207,7 +216,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = NullOpt,
Optional<SimplifyConfig> config_opt = std::nullopt,
bool simplify_arguments = false) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(
......@@ -229,6 +238,7 @@ public:
// Begin to remove useless var and buffer
// First get used buffers
simplifier.used_buffers_ = CollectUsedBuffers(func);
bool param_updated = false;
Array<Var> new_params;
Map<Var, Buffer> new_buffer_map;
......@@ -239,13 +249,18 @@ public:
simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else if (simplifier.used_in_buffer_def_.find(
func->buffer_map[var]->data.get()) !=
simplifier.used_in_buffer_def_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
}
}
if (simplify_arguments && param_updated) {
if (param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span);
} else {
......@@ -444,7 +459,7 @@ private:
arith::ProofStrength::kSymbolicBound)) {
return Bool(true);
}
return NullOpt;
return std::nullopt;
}
}
......@@ -452,7 +467,7 @@ private:
std::optional<ControlFlowGraph> touch_pattern_;
Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
Optional<Stmt> current_stmt_{std::nullopt};
std::unordered_set<const VarNode *> used_in_buffer_def_;
std::unordered_set<const VarNode *> used_vars_;
std::unordered_set<const BufferNode *> used_buffers_;
......@@ -469,7 +484,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file storage_rewrite.cc
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/type.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include "arith/int_operator.h"
#include "runtime/thread_storage_scope.h"
#include "tir/ir/buffer_common.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using runtime::StorageRank;
using runtime::StorageScope;
using namespace tir;
/*!
* \brief Perform data type legalization on the given BufferLoadNode pointer.
* Equal to BufferLoadNode::LegalizeDType, but operates on a pointer.
* \param n A pointer to a writable BufferLoadNode.
*/
static void LegalizeBufferLoadDType(BufferLoadNode *n) {
// Check that all indices except the last one have a scalar dtype
for (int i = 0; i < static_cast<int>(n->indices.size()) - 1; i++) {
ICHECK(n->indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
// If there are no indices, set the dtype to the buffer's dtype
if (n->indices.empty()) {
n->dtype = n->buffer->dtype;
} else {
auto index_dtype = n->indices.back().dtype();
bool is_buffer_dtype_scalable = n->buffer->dtype.is_scalable_vector();
bool is_index_scalable = index_dtype.is_scalable_vector();
// Do not allow both index dtype and buffer dtype to be scalable vectors
ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
<< "Index dtype and buffer dtype cannot both be scalable.";
if (is_index_scalable) {
// Index is a scalable vector, while the buffer is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
index_dtype.vscale_factor() * n->buffer->dtype.lanes());
} else if (is_buffer_dtype_scalable) {
// The buffer is a scalable vector, while the index is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
n->buffer->dtype.vscale_factor() * index_dtype.lanes());
} else {
// Neither side is a scalable vector, multiply lanes
n->dtype = n->buffer->dtype.with_lanes(index_dtype.lanes() *
n->buffer->dtype.lanes());
}
}
}
/*!
* \brief collect the mapping from the buffer var to its allocate
*/
class AllocateCollector : public StmtExprVisitor {
private:
bool IsDynamicSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == ".dyn";
}
bool IsStaticSharedMemory(Var buffer_var) {
StorageScope storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
return storage_scope.rank == runtime::StorageRank::kShared &&
storage_scope.tag == "";
}
public:
void VisitStmt_(const AllocateNode *op) final {
if (IsDynamicSharedMemory(op->buffer_var)) {
dyn_shmem_allocs_[op->buffer_var.get()] = op;
} else if (IsStaticSharedMemory(op->buffer_var)) {
static_shmem_allocs_[op->buffer_var.get()] = op;
}
StmtExprVisitor::VisitStmt_(op);
}
// The dynamic mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *> dyn_shmem_allocs_;
// The static mapping from the original buffer var to its allocate
std::unordered_map<const VarNode *, const AllocateNode *>
static_shmem_allocs_;
};
// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two
// points: before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores the before_scope and after_scope entries.
// The accesses to the arrays are recorded at the after_scope point.
//
// Define a "scope" as the body of a For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which memory must be kept
// alive within the same scope as its allocate.
// The storage needs to stay alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
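// Illustrative sketch (hypothetical IR): for
//   alloc A; for (i) { A[i] = ...; ... }
// the sequence contains a before_scope entry for the loop, entries for the
// loop body, and an after_scope entry; accesses to A made inside the loop are
// attributed to the entry at A's allocation scope, so liveness analysis keeps
// A alive across the whole loop.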
class LinearAccessPatternFinder final : public StmtExprVisitor {
public:
/*! \brief Record the touch history of a statement. */
struct StmtEntry {
// The statement
const Object *stmt;
// The index in linear_seq_ pointing to the other end of the nested scope.
// This is only set to non-zero if stmt is a nested scope.
// If offset > 0, this entry is the begin of the scope and the end entry is
// at current_index + offset; if offset < 0, this entry is the end and the
// begin entry is at current_index + offset.
int64_t scope_pair_offset{0};
// The buffer variables this statement touched.
std::vector<const VarNode *> touched;
};
// The scope of each allocation
struct AllocEntry {
// The physical dimension of the allocation.
size_t num_physical_dimensions{0};
// scope level
size_t level{0};
// allocation stmt
const AllocateNode *alloc{nullptr};
};
void VisitStmt_(const AllocateNode *op) final {
size_t level = scope_.size();
const VarNode *buf = op->buffer_var.get();
AllocEntry entry;
entry.alloc = op;
entry.level = level;
// Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
// all allocations specify the extents of their physical dimensions, and
// the number of physical dimensions is 1 for flat memory spaces.
entry.num_physical_dimensions = op->extents.size();
alloc_info_[buf] = entry;
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
all_buffers_accessed_.insert(op->buffer.get());
// Add write access.
const VarNode *buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
}
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void VisitExpr_(const BufferLoadNode *op) final {
// Add read access.
StmtExprVisitor::VisitExpr_(op);
all_buffers_accessed_.insert(op->buffer.get());
const VarNode *buffer_var = op->buffer->data.get();
auto it = alloc_info_.find(buffer_var);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buffer_var);
ICHECK_EQ(op->buffer->axis_separators.size() + 1,
it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions"
<< std::endl;
}
}
void VisitStmt_(const EvaluateNode *op) final {
scope_.push_back(StmtEntry());
// visit subexpr
StmtExprVisitor::VisitStmt_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.touched.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void VisitExpr_(const VarNode *buf) final {
// A direct reference to the variable counts as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second.level].touched.push_back(buf);
}
}
template <typename T> void VisitNewScope(const T *op) {
scope_.push_back(StmtEntry());
StmtEntry e;
e.stmt = op;
int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
// before scope.
linear_seq_.push_back(e);
StmtExprVisitor::VisitStmt_(op);
// after scope.
e.touched = std::move(scope_.back().touched);
scope_.pop_back();
int64_t end_index = static_cast<int64_t>(linear_seq_.size());
ICHECK_GT(end_index, begin_index);
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e);
// record the pointer to end index.
ICHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}
void VisitStmt_(const AttrStmtNode *op) final {
// Only record the outermost thread extent.
if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) {
in_thread_env_ = true;
VisitNewScope(op);
in_thread_env_ = false;
} else if (op->attr_key == tir::attr::extern_scope) {
VisitNewScope(op);
} else if (op->attr_key == tir::attr::virtual_thread) {
VisitNewScope(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }
void VisitStmt_(const ForNode *op) final { VisitNewScope(op); }
void VisitStmt_(const WhileNode *op) final { VisitNewScope(op); }
void VisitStmt_(const AssertStmtNode *op) final { VisitNewScope(op); }
void VisitStmt_(const LetStmtNode *op) final { VisitNewScope(op); }
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
std::unordered_map<const VarNode *, AllocEntry> alloc_info_;
// A record of which Buffer objects have been accessed, to prune
// unused DeclBuffer instances.
std::unordered_set<const BufferNode *> all_buffers_accessed_;
private:
// Whether already in thread env.
bool in_thread_env_{false};
// The scope stack.
std::vector<StmtEntry> scope_;
};
// Verify if the statement can be run safely in an in-place fashion
//
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times
//
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
//
// A[0] = B[0] + 1
// A[0] = B[0] + 1
//
// The high-level code generator needs to ensure that the generated
// code writes each location of the target array only once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block
// which can be arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that verifies every
// index is visited only once, which would give a sufficient condition.
//
// The code after inplace transformation is no longer idempotent.
//
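// Illustrative examples (hypothetical statements):
//   dst[i] = src[i] + 1      -> accepted: src is only read at the stored index
//   dst[i] = dst[i] + src[i] -> rejected: dst is also read (reduction)
//   dst[i] = src[B[i]]       -> rejected: indirect (nested) memory load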
class InplaceOpVerifier : public StmtExprVisitor {
public:
bool Check(const Object *stmt, const VarNode *dst, const VarNode *src) {
dst_ = dst;
src_ = src;
result_ = true;
if (stmt->IsInstance<AttrStmtNode>()) {
VisitStmt_(static_cast<const AttrStmtNode *>(stmt));
} else if (stmt->IsInstance<ForNode>()) {
VisitStmt_(static_cast<const ForNode *>(stmt));
} else if (stmt->IsInstance<IfThenElseNode>()) {
VisitStmt_(static_cast<const IfThenElseNode *>(stmt));
} else if (stmt->IsInstance<WhileNode>()) {
VisitStmt_(static_cast<const WhileNode *>(stmt));
} else if (stmt->IsInstance<BufferStoreNode>()) {
VisitStmt_(static_cast<const BufferStoreNode *>(stmt));
} else {
return false;
}
return result_;
}
using StmtExprVisitor::VisitStmt_;
void VisitStmt(const Stmt &n) final {
if (!result_)
return;
StmtExprVisitor::VisitStmt(n);
}
void VisitExpr(const PrimExpr &n) final {
if (!result_)
return;
StmtExprVisitor::VisitExpr(n);
}
void VisitExpr_(const VarNode *op) final {
// assume all opaque access is unsafe
if (op == dst_ || op == src_) {
result_ = false;
return;
}
}
void VisitStmt_(const BufferStoreNode *op) final {
++mem_nest_;
for (const auto &index : op->indices) {
this->VisitExpr(index);
}
--mem_nest_;
if (op->buffer->data.get() == dst_) {
store_ = op;
this->VisitExpr(op->value);
store_ = nullptr;
} else {
this->VisitExpr(op->value);
}
}
void VisitStmt_(const AttrStmtNode *op) final {
// always reject extern code
if (op->attr_key == tir::attr::extern_scope ||
op->attr_key == tir::attr::volatile_scope) {
result_ = false;
return;
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) final {
const VarNode *buf = op->buffer->data.get();
// cannot read from dst_ (no reduction)
if (buf == dst_) {
result_ = false;
return;
}
// do not allow indirect memory load
if (mem_nest_ != 0) {
result_ = false;
return;
}
if (src_ == buf) {
if (store_ == nullptr || store_->value.dtype() != op->dtype) {
result_ = false;
return;
}
ICHECK_EQ(store_->indices.size(), op->indices.size())
<< "Store/Load occur to the same buffer " << buf->name_hint
<< " with differing number of indices";
for (size_t i = 0; i < store_->indices.size(); i++) {
if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) {
result_ = false;
return;
}
}
}
++mem_nest_;
StmtExprVisitor::VisitExpr_(op);
--mem_nest_;
}
private:
// result of the check
bool result_{true};
// destination memory
const VarNode *dst_;
// source variable
const VarNode *src_;
// counter of load,
// it is not safe to inplace when there is nested load like A[B[i]]
int mem_nest_{0};
// The current store to be inspected
const BufferStoreNode *store_{nullptr};
};
/* \brief Rewrite and merge memory allocation.
*
* Using LinearAccessPatternFinder, determines which buffers could share an
* allocation. This includes both sequential usage of the same buffer and
* merging small allocations at the same scope into a single larger allocation.
* The merging of small allocations requires the codegen to cast the resulting
* value from the storage type to the output type after access.
*/
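/* Illustrative sketch (hypothetical buffers): if local buffer A (float32[256])
 * is no longer live when local buffer B (float32[256]) is first written, both
 * can be backed by the same allocation; small allocations in specially tagged
 * memory at the same scope may instead be folded into one larger allocation
 * addressed via bit offsets.
 */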
class StoragePlanRewriter : public StmtExprMutator {
public:
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
detect_inplace_ = detect_inplace;
// plan the rewrite
LinearAccessPatternFinder finder;
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
reuse_require_exact_matched_dtype);
all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
stmt = operator()(std::move(stmt));
if (attach_map_.count(nullptr)) {
return MakeAttach(attach_map_.at(nullptr), stmt);
}
return stmt;
}
template <typename Node> Node VisitBufferAccess(Node node) {
auto it = alloc_map_.find(node->buffer->data.get());
if (it != alloc_map_.end()) {
Buffer buf = RemapBuffer(node->buffer, it->second->alloc_var);
Array<PrimExpr> indices = node->indices;
indices.Set(indices.size() - 1,
RemapIndex(node->buffer->dtype, indices[indices.size() - 1],
it->second));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
}
return node;
}
Buffer RemapBuffer(Buffer buf, Var new_backing_array) {
auto key = buf.get();
auto it = buffer_remap_.find(key);
if (it != buffer_remap_.end()) {
ICHECK_EQ(it->second->data.get(), new_backing_array.get())
<< "Cannot remap buffer " << buf->name << " to use backing array "
<< new_backing_array->name_hint << ", previously used backing array "
<< it->second->data->name_hint;
return it->second;
}
Buffer remapped = Buffer(
new_backing_array, buf->dtype, buf->shape, buf->strides,
buf->elem_offset, new_backing_array->name_hint, buf->data_alignment,
buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span);
buffer_remap_[key] = remapped;
return remapped;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return VisitBufferAccess(std::move(node));
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
}
PrimExpr VisitExpr_(const VarNode *op) final {
auto it = alloc_map_.find(op);
if (it != alloc_map_.end()) {
if (it->second->bits_offset != 0) {
LOG(WARNING)
<< "Use a merged buffer variable address, could cause error";
}
return it->second->alloc_var;
} else {
return GetRef<PrimExpr>(op);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
auto it = alloc_map_.find(buffer);
if (it == alloc_map_.end()) {
return StmtExprMutator::VisitExpr_(op);
}
const StorageEntry *se = it->second;
PrimExpr offset = this->VisitExpr(op->args[2]);
PrimExpr extent = this->VisitExpr(op->args[3]);
uint64_t elem_bits = dtype.bits() * dtype.lanes();
ICHECK_EQ(se->bits_offset % elem_bits, 0U);
if (se->bits_offset != 0) {
offset =
make_const(offset.dtype(), se->bits_offset / elem_bits) + offset;
}
return Call(op->dtype, op->op,
{op->args[0], se->alloc_var, offset, extent, op->args[4]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread ||
tir::attr::IsPragmaKey(op->attr_key)) {
// remake all the allocations at the attach scope.
if (attach_map_.count(op)) {
auto &svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
return AttrStmt(op->node, op->attr_key, op->value,
MakeAttach(svec, op->body));
} else {
return StmtExprMutator::VisitStmt_(op);
}
} else if (op->attr_key == tir::attr::volatile_scope) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
auto it = alloc_map_.find(op->node.as<VarNode>());
if (it == alloc_map_.end())
return stmt;
return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const ForNode *op) final {
ICHECK(op->kind != ForKind::kVectorized)
<< "VectorizeLoop before LiftStorageAlloc";
// remake all the allocations at the attach scope.
if (attach_map_.count(op)) {
auto &svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
return For(op->loop_var, op->min, op->extent, op->kind,
MakeAttach(svec, op->body), op->thread_binding,
op->annotations);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
return this->VisitStmt(op->body);
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
if (hoisted_buffer_decls_.count(op->buffer.get()) ||
!all_buffers_accessed_.count(op->buffer.get())) {
return this->VisitStmt(op->body);
}
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
if (auto it = alloc_map_.find(op->buffer->data.get());
it != alloc_map_.end()) {
Buffer buf = RemapBuffer(op->buffer, it->second->alloc_var);
node.CopyOnWrite()->buffer = buf;
}
return std::move(node);
}
private:
struct StorageEntry {
// The scope that this alloc attaches after
// For shared/local memory it is the beginning of the thread extent.
// For global memory it is nullptr, meaning the beginning of everything.
const Object *attach_scope_{nullptr};
// The constant size of the buffer in bits, only used if it is constant
uint64_t const_nbits{0};
// The storage scope.
StorageScope scope;
// The physical dimensionality of the allocations. Since
// StorageRewrite is applied after StorageFlatten/FlattenBuffer,
// this is the size of `AllocateNode::extents`. If this pass were moved
// earlier in lowering, this would instead need to be derived from the
// buffer shape.
size_t ndim;
// Allocs that share this entry.
std::vector<const AllocateNode *> allocs;
// The children of this entry, not including itself.
std::vector<StorageEntry *> merged_children;
// The replacement Allocate, if any. May also include associated
// DeclBuffer statement.
std::vector<Stmt> alloc_nest;
// The var expr of new allocation.
Var alloc_var;
// The allocation element type.
DataType elem_type;
// This is non-zero if this allocate is folded into another one;
// the address (in bits) then becomes alloc_var + bits_offset, which can be
// converted to an offset in the element type. We need to convert
// bits_offset to an offset of the specific element type later.
//
// We use bits (instead of bytes) to support non-conventional indexing in
// hardware. When merging buffers together, bits_offset is aligned to the
// value given by the max_simd_bits property of the special memory.
//
// This allows effective sharing among different types as long as their
// alignment requirement fits into the max_simd_bits.
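//
// Worked example (illustrative): if a float16 buffer is folded in after a
// 1024-bit region, then bits_offset = 1024 and RemapIndex rewrites its
// accesses to alloc_var[index + 1024 / 16] = alloc_var[index + 64].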
uint64_t bits_offset{0};
};
// Checks whether the storage_scope is specially tagged for a specific
// memory. Special memory is all combined into a single allocation.
bool IsSpecialTaggedMemory(const StorageScope &scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" &&
scope.tag != ".workspace" && scope.tag != ".vtcm";
}
// Allocate entry of node.
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
std::vector<const VarNode *> gen;
// variables we kill
std::vector<const VarNode *> kill;
};
Stmt MakeAttach(const std::vector<StorageEntry *> &svec, Stmt body) {
for (auto it = svec.rbegin(); it != svec.rend(); it++) {
body = MergeNest((*it)->alloc_nest, body);
}
return body;
}
// Remap the index
PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry *e) {
if (e->bits_offset == 0)
return index;
uint64_t elem_bits = dtype.bits();
ICHECK_EQ(e->bits_offset % elem_bits, 0U);
return make_const(index.dtype(), e->bits_offset / elem_bits) + index;
}
// Prepare the new allocations
void PrepareNewAlloc() {
for (size_t i = 0; i < alloc_vec_.size(); ++i) {
StorageEntry *e = alloc_vec_[i].get();
attach_map_[e->attach_scope_].push_back(e);
}
// find allocation via attach map.
for (auto &kv : attach_map_) {
// find the entry with the largest size in bytes.
std::vector<StorageEntry *> &vec = kv.second;
// try to find merge, for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry *e = vec[i];
if (IsSpecialTaggedMemory(e->scope)) {
ICHECK_NE(e->const_nbits, 0U)
<< "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
vec[j]->merged_children.push_back(e);
break;
}
}
}
}
// Start allocation
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry *e = vec[i];
// already merged
if (e->bits_offset != 0)
continue;
if (e->merged_children.size() != 0) {
NewAllocTagMerged(e);
continue;
}
// Get the allocation size;
e->alloc_var = e->allocs[0]->buffer_var;
DataType alloc_type = e->allocs[0]->dtype;
for (const AllocateNode *op : e->allocs) {
if (op->dtype.lanes() > alloc_type.lanes()) {
alloc_type = op->dtype;
}
}
bool all_allocs_identical = std::all_of(
e->allocs.begin() + 1, e->allocs.end(),
[&](const AllocateNode *op) -> bool {
const AllocateNode *first = *e->allocs.begin();
if (op->dtype != first->dtype) {
return false;
}
if (op->extents.size() != first->extents.size()) {
return false;
}
ExprDeepEqual expr_equal;
for (size_t i = 0; i < op->extents.size(); i++) {
if (!expr_equal(op->extents[i], first->extents[i])) {
return false;
}
}
return true;
});
if (all_allocs_identical) {
// simply use the original allocation.
e->alloc_nest.push_back(
Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0)));
if (auto ptr = e->allocs[0]->body.as<DeclBufferNode>()) {
e->alloc_nest.push_back(DeclBuffer(
RemapBuffer(ptr->buffer, e->alloc_var), Evaluate(0)));
hoisted_buffer_decls_.insert(ptr->buffer.get());
}
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag "
<< e->scope.to_string();
}
}
} else {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode *op : e->allocs) {
ICHECK_EQ(op->extents.size(), 1)
<< "Buffer var " << op->buffer_var->name_hint
<< " was identified as a re-usable allocation, but has "
<< op->extents.size() << " physical dimensions. "
<< "Currently, only flat 1-d memory spaces should be "
"identified as re-usable "
"allocations.";
PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto *imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
LOG(WARNING) << "The allocation requires : " << imm->value
<< " * " << nbits
<< " bits, which is greater than the maximum of"
" int32. The size is cast to int64."
<< "\n";
sz = make_const(DataType::Int(64), imm->value);
}
}
// transform to bits
auto sz_nbits = sz * nbits;
if (combo_size.defined()) {
combo_size = max(combo_size, sz_nbits);
} else {
combo_size = sz_nbits;
}
}
// transform from bits to the number of alloc_type elements
auto type_bits = alloc_type.bits() * alloc_type.lanes();
bool divided =
analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
combo_size = indexdiv(combo_size, type_bits);
// round up when the size cannot be divided evenly
if (!divided) {
combo_size = combo_size + make_const(DataType::Int(32), 1);
}
combo_size = analyzer_.Simplify(combo_size);
e->alloc_nest.push_back(Allocate(e->alloc_var, alloc_type,
{combo_size}, const_true(),
Evaluate(0)));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
if (info.defined()) {
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag "
<< e->scope.to_string();
}
}
}
}
}
}
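// Worked example (illustrative) for the merged-size computation above:
// allocations of float32[100] (3200 bits) and float16[300] (4800 bits) give
// combo_size = max(3200, 4800) = 4800 bits; with alloc_type = float32
// (32 bits) this becomes 4800 / 32 = 150 elements, with no round-up needed
// since the division is exact.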
// New allocation for merged data
void NewAllocTagMerged(StorageEntry *e) {
ICHECK_NE(e->scope.tag.length(), 0U);
// allocate with element type.
ICHECK_NE(e->const_nbits, 0U);
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_bits = e->const_nbits;
// By default, align to 32 bits.
size_t align = 32;
if (info.defined()) {
align = info->max_simd_bits;
}
// Always align to max_simd_bits
// so we can remap types by keeping this property
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
e->alloc_var = e->allocs[0]->buffer_var;
for (StorageEntry *child : e->merged_children) {
ICHECK_NE(child->const_nbits, 0U);
ICHECK_NE(total_bits, 0U);
child->bits_offset = total_bits;
child->alloc_var = e->alloc_var;
total_bits += child->const_nbits;
if (total_bits % align != 0) {
total_bits += align - (total_bits % align);
}
}
uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(),
(total_bits + type_bits - 1) / type_bits);
e->alloc_nest.push_back(Allocate(e->alloc_var, e->elem_type, {alloc_size},
const_true(), Evaluate(0)));
if (info.defined()) {
ICHECK_LE(total_bits, info->max_num_bits)
<< "Allocation exceed bound of memory tag " << e->scope.to_string();
}
}
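// Worked example (illustrative): with align = 128 bits, a parent entry of
// 1000 bits is padded to 1024 bits; a merged child of 200 bits then starts at
// bits_offset = 1024, and the running total of 1224 bits is padded to 1280.
// With a 32-bit element type the final allocation holds
// (1280 + 32 - 1) / 32 = 40 elements.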
// Liveness analysis to find gen and kill point of each variable.
void LivenessAnalysis(const std::vector<StmtEntry> &seq) {
// find kill point, do a reverse linear scan.
std::unordered_set<const VarNode *> touched;
for (size_t i = seq.size(); i != 0; --i) {
const StmtEntry &s = seq[i - 1];
for (const VarNode *buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].kill.push_back(buffer);
}
}
}
// find gen point, do forward scan
touched.clear();
for (size_t i = 0; i < seq.size(); ++i) {
int64_t offset = seq[i].scope_pair_offset;
if (offset < 0)
continue;
const StmtEntry &s = seq[i + offset];
for (const VarNode *buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
event_map_[s.stmt].gen.push_back(buffer);
}
}
}
}
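// Illustrative sketch (hypothetical leaf statements): for a sequence
// {S1 touches A; S2 touches A and B; S3 touches B}, the reverse scan records
// kill(A) at S2 and kill(B) at S3, while the forward scan records gen(A) at
// S1 and gen(B) at S2.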
void PlanNewScope(const Object *op) {
if (thread_scope_ != nullptr) {
ICHECK(thread_scope_ == op);
// erase all memory attached to this scope.
for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
if (it->second->attach_scope_ == op) {
it = const_free_map_.erase(it);
} else {
++it;
}
}
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
if ((*it)->attach_scope_ == op) {
it = sym_free_list_.erase(it);
} else {
++it;
}
}
thread_scope_ = nullptr;
} else {
thread_scope_ = op;
}
}
// Memory plan algorithm
void
PlanMemory(const std::vector<StmtEntry> &seq,
const std::unordered_map<const VarNode *, AllocEntry> &alloc_info,
bool enable_reuse, bool reuse_require_exact_matched_dtype) {
std::unordered_set<const VarNode *> inplace_flag;
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
auto it = event_map_.find(seq[i].stmt);
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset > 0)
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
// Inplace operation detection
// specially handle this
bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
for (const VarNode *var : it->second.gen) {
ICHECK(alloc_info.count(var));
const AllocEntry &entry = alloc_info.at(var);
const AllocateNode *alloc = entry.alloc;
auto storage_scope =
StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
StorageEntry *dst_entry = nullptr;
// inplace detection
if (detect_inplace) {
// only one inplace var for s.stmt
bool inplace_found = false;
for (const VarNode *src : it->second.kill) {
if (!inplace_flag.count(src) && alloc_map_.count(src)) {
InplaceOpVerifier visitor;
StorageEntry *src_entry = alloc_map_.at(src);
if (src_entry->scope == storage_scope &&
src_entry->attach_scope_ == thread_scope_ &&
src_entry->elem_type == alloc->dtype.element_of() &&
visitor.Check(s.stmt, var, src)) {
uint64_t const_nbits =
static_cast<uint64_t>(alloc->ConstantAllocationSize()) *
alloc->dtype.bits() * alloc->dtype.lanes();
if (src_entry->const_nbits == const_nbits && !inplace_found) {
// successfully inplace
dst_entry = src_entry;
inplace_flag.insert(src);
inplace_found = true;
}
}
}
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
entry.num_physical_dimensions, enable_reuse,
reuse_require_exact_matched_dtype);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
}
}
// enter/exit new scope
if (s.stmt->IsInstance<AttrStmtNode>()) {
const auto *op = static_cast<const AttrStmtNode *>(s.stmt);
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread ||
tir::attr::IsPragmaKey(op->attr_key)) {
PlanNewScope(op);
} else {
ICHECK(op->attr_key == tir::attr::extern_scope);
}
} else if (s.stmt->IsInstance<ForNode>()) {
const auto *op = static_cast<const ForNode *>(s.stmt);
if (op->kind == ForKind::kParallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
PlanNewScope(op);
}
}
}
// scope_pair_offset <= 0 means it is either
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode *var : it->second.kill) {
// skip buffers that have already been replaced by inplace
if (!inplace_flag.count(var)) {
this->Free(var);
}
}
}
}
}
// Allocate new storage entry.
StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope,
const StorageScope &scope, size_t const_nbits) {
ICHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
auto entry = std::make_unique<StorageEntry>();
entry->attach_scope_ = attach_scope;
entry->scope = scope;
entry->elem_type = op->dtype.element_of();
entry->const_nbits = const_nbits;
StorageEntry *e = entry.get();
alloc_vec_.emplace_back(std::move(entry));
return e;
}
StorageEntry *FindAlloc(const AllocateNode *op, const Object *attach_scope,
const StorageScope &scope,
size_t num_physical_dimensions, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
ICHECK(op != nullptr);
// Skip planning for local variables;
// the compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits =
static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
// If the size of the array isn't known at compile-time, it must
// have its own allocation with size determined at runtime.
bool is_known_size = (const_nbits != 0);
// Currently, only flat memory spaces can be re-used. Packing
// into N-d space (e.g. 2-d texture memory on GPUs) will require
// more in-depth algorithms.
bool is_flat_memory_space = (num_physical_dimensions == 1);
// Disable reuse of small arrays; they will be lowered to registers in LLVM.
// This rule only applies when using non-special memory.
bool is_small_array =
(scope.tag.length() == 0) &&
(scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));
if (!enable_reuse || is_small_array || !is_flat_memory_space) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (is_known_size) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// start looking at the buffer that is bigger than the required size first
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
// when not evenly divisible, no reuse, e.g. float4 vs float3
if (e->bits_offset % op_elem_bits != 0)
continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
// then start looking at smaller buffers.
for (auto it = mid; it != begin;) {
--it;
StorageEntry *e = it->second;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
if (e->elem_type != op->dtype.element_of())
continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
}
} else {
// Simple strategy: round robin.
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
StorageEntry *e = *it;
if (e->attach_scope_ != attach_scope)
continue;
if (e->scope != scope)
continue;
if (e->elem_type != op->dtype.element_of())
continue;
sym_free_list_.erase(it);
return e;
}
}
return NewAlloc(op, attach_scope, scope, const_nbits);
}
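// Worked example (illustrative): for a request of const_nbits = 1024 with
// match_range = 16, the constant free map is first searched for entries in
// [1024, 16384] (buffers at least as large), and then for entries in
// [64, 1024) (smaller buffers with a matching element type).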
// simulated free.
void Free(const VarNode *var) {
auto it = alloc_map_.find(var);
ICHECK(it != alloc_map_.end());
StorageEntry *e = it->second;
ICHECK_NE(e->allocs.size(), 0U);
// Disable reuse of small arrays; they will be lowered to registers in LLVM.
// This rule only applies when using non-special memory.
if (e->scope.tag.length() == 0) {
// Disable sharing of local memory.
if (e->scope.rank >= StorageRank::kWarp ||
e->allocs[0]->dtype.is_handle())
return;
// disable reuse of small arrays
if (e->const_nbits > 0 && e->const_nbits <= 32)
return;
}
// normal free.
if (e->const_nbits != 0) {
const_free_map_.insert({e->const_nbits, e});
} else {
sym_free_list_.push_back(e);
}
}
// thread scope.
const Object *thread_scope_{nullptr};
// whether enable inplace detection.
bool detect_inplace_{false};
// Locations of free ops.
std::unordered_map<const Object *, EventEntry> event_map_;
// constant size free map.
std::multimap<uint64_t, StorageEntry *> const_free_map_;
// symbolic free list, for non constant items.
std::list<StorageEntry *> sym_free_list_;
// The allocation attach map
std::unordered_map<const Object *, std::vector<StorageEntry *>> attach_map_;
// The allocation assign map
std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
// The allocations
std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
// The buffer objects being remapped
std::unordered_map<const BufferNode *, Buffer> buffer_remap_;
// Buffers whose DeclBuffer has been hoisted to be adjacent to the new
// Allocate location
std::unordered_set<const BufferNode *> hoisted_buffer_decls_;
// Any buffers that is accessed at some point. DeclBuffer instances
// that do not appear in this list may be removed.
std::unordered_set<const BufferNode *> all_buffers_accessed_;
// analyzer
arith::Analyzer analyzer_;
};
/* Helper struct containing information on how a buffer is declared and used
*
*/
struct BufferVarInfo {
enum DeclarationLocation {
kPrimFuncParam = (1 << 0),
kPrimFuncBufferMap = (1 << 1),
kAllocateNode = (1 << 2),
kAllocateConstNode = (1 << 3),
kLetNode = (1 << 4),
};
// The tir::Var that represents this buffer.
Var var;
// The data type of an element of the buffer.
DataType element_dtype;
/* The extent of the buffer.
*
* If multidimensional, the extent of the last dimension of the buffer. If
* the size is unknown (e.g. pointer arguments to PrimFunc with no
* corresponding entry in buffer_map), then extent is zero.
*/
PrimExpr extent;
// Where the buffer was declared
DeclarationLocation declaration_location;
// When accessed, which element type it is accessed as. This may
// differ from the declared type either in base type (e.g. int32* cast to
// float32* after packing in StorageRewrite) or in the number of lanes
// (e.g. float16* cast to float16x4*).
std::unordered_set<DataType> access_dtype;
// Data types used for scalar reads. This is used to record vectorized read
// dtypes that can be shuffled for scalar reads when
// rewrite_scalar_read_to_vector_shuffle is enabled.
std::unordered_set<DataType> scalar_read_dtype;
DataType get_preferred_dtype() const {
std::unordered_set<DataType> base_access_dtype;
for (auto dtype : access_dtype) {
base_access_dtype.insert(dtype.element_of());
}
for (auto dtype : scalar_read_dtype) {
base_access_dtype.insert(dtype.element_of());
}
// If the array is accessed as multiple base types within a
// function, no point in changing the declared type. CodeGenC can
// handle this with a type-cast prior to indexing. Vulkan will
// raise an error at code-gen time, if a later pass doesn't split
// it out.
if (base_access_dtype.size() != 1) {
return element_dtype;
}
DataType preferred_base_type = *base_access_dtype.begin();
// If there is only one vectorizable size used to access the
// buffer, and if that access size is compatible with the array
// size, then the buffer is vectorizable. In the future, this
// could be improved to allow vectorized buffer access of size
// GCD(*lanes_used), if necessary.
// When there are scalar reads and no writes, access_dtype can be empty and
// we should avoid rewriting.
int preferred_lanes = element_dtype.lanes();
if (element_dtype.lanes() == 1 && (access_dtype.size() == 1)) {
int lanes = access_dtype.begin()->lanes();
// Check the scalar read dtypes are compatible with the vectorized access
// dtype.
for (auto dtype : scalar_read_dtype) {
if (dtype.lanes() % lanes != 0) {
return element_dtype;
}
}
arith::Analyzer analyzer_;
arith::ModularSet me = analyzer_.modular_set(extent);
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
preferred_lanes = lanes;
}
}
return preferred_base_type.with_lanes(preferred_lanes);
}
};
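// Illustrative example (hypothetical buffer): a float16* buffer whose accesses
// are all float16x4 and whose last-dimension extent is provably a multiple of
// 4 gets preferred dtype float16x4; if it is also accessed as float32, the
// declared element type is left unchanged.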
/* Checks whether buffers are accessed as scalar or vector parameters in a
* function.
*
*/
class VectorTypeAccessChecker : public StmtExprVisitor {
public:
/* Constructor
*
* @param params The parameters passed to a PrimFunc
*
* @param buffer_map The buffer_map associated with a PrimFunc
*
 * @param allow_untyped_pointers If a buffer or pointer variable is
* missing a type annotation, assume that it has the same underlying
* type as it is later accessed, with scalar element types.
*/
VectorTypeAccessChecker(const Array<tir::Var> &params,
const Map<Var, Buffer> &buffer_map,
bool allow_untyped_pointers = false,
bool detect_scalar_read_patterns = true)
: allow_untyped_pointers_(allow_untyped_pointers),
detect_scalar_read_patterns_(detect_scalar_read_patterns) {
// If a parameter is in the buffer map, we want to track the
// version in the map.
for (auto it : buffer_map) {
Buffer &buffer = it.second;
Var buffer_var = buffer->data;
DataType dtype = buffer->dtype;
PrimExpr extent =
buffer->shape.size() ? buffer->shape[buffer->shape.size() - 1] : 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncParam);
}
// If a pointer parameter isn't in the buffer map, then we want to
// track the parameter itself.
for (Var buffer_var : params) {
auto pointer_type = GetPointerType(buffer_var->type_annotation);
if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
DataType dtype = pointer_type.value();
PrimExpr extent = 0;
OnArrayDeclaration(buffer_var, dtype, extent,
BufferVarInfo::kPrimFuncBufferMap);
}
}
}
void VisitExpr_(const BufferLoadNode *op) final {
OnArrayAccess(op->dtype, op->buffer->data.get(), op->indices,
/*is_buffer_load=*/true);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
OnArrayAccess(op->value.dtype(), op->buffer->data.get(), op->indices,
/*is_buffer_load=*/false);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
DataType dtype = op->args[0].dtype();
const VarNode *buffer = op->args[1].as<VarNode>();
PrimExpr index = op->args[2];
OnArrayAccess(dtype, buffer, {index}, false);
} else if (op->op.same_as(builtin::address_of())) {
if (auto load = op->args[0].as<BufferLoadNode>()) {
OnArrayAccess(load->dtype, load->buffer->data.get(), load->indices,
/*is_buffer_load=*/false);
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const AllocateNode *op) final {
const Array<PrimExpr> &extents = op->extents;
PrimExpr extent = extents[extents.size() - 1];
OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateNode);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AllocateConstNode *op) final {
const Array<PrimExpr> &extents = op->extents;
PrimExpr extent =
extents.size() ? extents[extents.size() - 1] : NullValue<PrimExpr>();
OnArrayDeclaration(op->buffer_var, op->dtype, extent,
BufferVarInfo::kAllocateConstNode);
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const LetNode *op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const LetStmtNode *op) final {
HandleLetNode(op->var);
StmtExprVisitor::VisitStmt_(op);
}
void HandleLetNode(Var let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
if (pointer_type.has_value()) {
OnArrayDeclaration(let_var, pointer_type.value(), 0,
BufferVarInfo::kLetNode);
} else if (allow_untyped_pointers_) {
OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
} else {
LOG(FATAL) << "Let statement of variable " << let_var->name_hint
<< " is missing a type annotation, "
<< "or type annotation is not a pointer to primitive";
}
}
}
/* Update the type map for a buffer based on its declaration
*
* @param buffer The VarNode representing the buffer.
*
* @param element_dtype The dtype of a single element of the buffer.
* If unknown, when used with the allow_untyped_handles option,
* should be a handle dtype.
*
* @param extent The extent of the buffer. Zero if size is unknown.
*
* @param declaration_location How the buffer was allocated, so that
* some locations can be rewritten without others.
*/
void
OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent,
BufferVarInfo::DeclarationLocation declaration_location) {
ICHECK(info_map_.find(buffer.get()) == info_map_.end())
<< "Array declaration of " << buffer->name_hint
<< " occurred multiple times.";
if (element_dtype == DataType::Bool()) {
element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());
}
info_map_[buffer.get()] =
BufferVarInfo{buffer, element_dtype, extent, declaration_location};
}
/* Update the type map for a buffer based on its usage
*
* @param value_dtype The dtype of the value being stored to or
* loaded from the buffer.
*
* @param buffer The VarNode representing the buffer.
*
* @param indices The index at which the value is being stored/loaded.
*
* @param is_buffer_load Whether the access is BufferLoad
*/
void OnArrayAccess(DataType value_dtype, const VarNode *buffer,
const Array<PrimExpr> &indices, bool is_buffer_load) {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end())
<< "Load/Store of buffer " << buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";
if (value_dtype.is_scalable_vector()) {
// Scalable types are not currently supported in storage_rewrite. Scalable
// buffer accesses are not currently checked and therefore are not
// rewritten.
return;
}
BufferVarInfo &var_info = it->second;
if (value_dtype.element_of() == DataType::Bool()) {
value_dtype = DataType::Int(8).with_lanes(value_dtype.lanes());
}
if (var_info.element_dtype.is_handle()) {
ICHECK(allow_untyped_pointers_)
<< "Variable " << buffer->name_hint
<< " was missing a type annotation in its declaration";
var_info.element_dtype = value_dtype.element_of();
}
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
ICHECK(indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
DataType access_dtype = value_dtype;
int lanes_used = var_info.element_dtype.lanes();
// This can happen due to a previous pass that had rewrite_store_load =
// false. This occurs from the StorageRewrite in tvm::lower, followed by
// the PointerValueTypeRewrite in BuildSPIRV. The rewrite_store_load =
// false is necessary because the C-based codegens do not yet support
// vectorized pointer types (e.g. float16x4*). Once they do, this if
// statement should instead be replaced by the below ICHECK_EQ.
if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) {
ICHECK_EQ(index_lanes, value_dtype.lanes());
lanes_used = 1;
var_info.element_dtype = var_info.element_dtype.with_lanes(1);
}
// TODO(Lunderberg): Uncomment this check once it can be applied.
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
// ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(),
// value_dtype.lanes())
// << "Attempting to retrieve " << value_dtype.lanes() << " lanes of
// data with "
// << index_lanes << " indices into an array whose elements have "
// << var_info.element_dtype.lanes() << " lanes. "
// << "Expected output with " << index_lanes *
// var_info.element_dtype.lanes()
// << " lanes.";
// If the index is a RampNode with stride of 1 and offset
// divisible by the number of lanes, and the predicate
// does not apply any masking, then this array access could be
// vectorized.
if (indices.size()) {
const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
if (ramp_index->lanes->IsInstance<IntImmNode>()) {
int lanes =
static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
lanes_used = lanes;
}
}
}
}
if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) {
const PrimExpr last_dim_index = indices[indices.size() - 1];
if (last_dim_index.dtype().lanes() == 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff));
return;
}
}
var_info.access_dtype.insert(access_dtype.with_lanes(lanes_used));
}
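// Illustrative example (hypothetical access): a float32x4 store whose last
// index is Ramp(base = 8, stride = 1, lanes = 4) has base and modular
// coefficient divisible by 4, so lanes_used = 4 and float32x4 is recorded in
// access_dtype; a scalar load A[k] instead contributes to scalar_read_dtype
// when detect_scalar_read_patterns_ is enabled.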
// Map of buffer variable information determined
std::unordered_map<const VarNode *, BufferVarInfo> info_map_;
// Whether to allow pointer variables without a pointed-to type annotation
bool allow_untyped_pointers_{false};
// Whether to detect scalar read patterns for rewriting to vector shuffle
bool detect_scalar_read_patterns_{true};
// internal analyzer
arith::Analyzer analyzer_;
};
/* \brief Rewrites buffer/pointer variables from scalar types to vectorized
* types.
*
* Some runtimes do not allow casting between composite types and the underlying
* base type (e.g. Vulkan, casting from 1-lane float16* to 4-lane float16x4*).
* In these cases, in order to have vectorized load/store on an array, the
* element type of that array must be vectorized. This is in contrast to
* C-style runtimes, in which `float16x4* vec = *(float16x4*)(float_arr +
* offset)` is valid.
*
* By default, VectorTypeRewriter will attempt to rewrite all buffer variables
* to vectorized access, if the load/store occurring in the PrimFunc are all
* vectorized. This includes adjusting the indices being used to access the
* array. (e.g. If `float16* scalar_arr` is being converted to `float16x4*
* vec_arr`, then `scalar_arr[Ramp(offset, 1, 4)]` will be converted to
* `vec_arr[offset/4]`.)
*
* Currently, several of the C-style runtimes do not support buffers whose
* elements are vectorized types, or rely on the presence of the Ramp nodes to
* identify vectorized loads. The boolean parameters in the constructor are to
* mimic the previous behavior of VectorTypeRewriter, to avoid breaking these
* runtimes. Once all runtimes support vectorized buffer elements, these
* parameters can be removed.
*/
class VectorTypeRewriter : public StmtExprMutator {
public:
/* Constructor
*
* @param checker The VectorTypeAccessChecker that has previously read out
* information from the PrimFunc
*
* @param rewrite_params Whether pointer-type parameters passed into the
* function should be rewritten from scalar types to vectorized types.
*
* @param rewrite_buffer_map Whether buffers present in the buffer_map should
* have their data variable be rewritten from scalar types to vectorized
* types.
*
* @param rewrite_allocate_node Whether the buffer variable associated with
* AllocateNodes should be rewritten from scalar types to vectorized types.
*
* @param rewrite_indices Whether the indices to the Load and Store nodes
* should be rewritten to correspond to the new buffer_var type.
*
* @param rewrite_let_node Whether pointer declarations in let nodes
* should be re-written.
*/
VectorTypeRewriter(
const std::unordered_map<const VarNode *, BufferVarInfo> &info_map,
bool rewrite_params = true, bool rewrite_buffer_map = true,
bool rewrite_allocate_node = true, bool rewrite_indices = true,
bool rewrite_let_node = true, bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true)
: rewrite_indices_(rewrite_indices) {
int rewrite_mask = 0;
if (rewrite_params) {
rewrite_mask |= BufferVarInfo::kPrimFuncParam;
}
if (rewrite_buffer_map) {
rewrite_mask |= BufferVarInfo::kPrimFuncBufferMap;
}
if (rewrite_allocate_node) {
rewrite_mask |= BufferVarInfo::kAllocateNode;
}
if (rewrite_let_node) {
rewrite_mask |= BufferVarInfo::kLetNode;
}
if (rewrite_allocate_const_node) {
rewrite_mask |= BufferVarInfo::kAllocateConstNode;
}
// Rewrite any buffer variables whose preferred type isn't their current
// type.
for (const auto &pair : info_map) {
const auto &var_info = pair.second;
DataType preferred = var_info.get_preferred_dtype();
if (preferred != var_info.element_dtype &&
(rewrite_mask & var_info.declaration_location)) {
Var old_buffer_var = var_info.var;
Var new_buffer_var(old_buffer_var->name_hint,
PointerType(PrimType(preferred),
GetPtrStorageScope(old_buffer_var)),
old_buffer_var->span);
rewrite_map_[var_info.var.get()] = {var_info.var, new_buffer_var,
var_info.element_dtype, preferred};
}
}
}
/*!
* \brief Mutator for BufferLoad or BufferStore.
* \return The rewritten node and the shuffle index. (Only for BufferLoad)
* When the shuffle index is non-negative, the caller should generate Shuffle
* to extract the element from the vector.
*/
template <typename Node> std::pair<Node, int> VisitBufferAccess(Node node) {
int shuffle_index = -1;
if (!rewrite_indices_) {
return {node, shuffle_index};
}
auto it = rewrite_map_.find(node->buffer->data.get());
if (it == rewrite_map_.end()) {
return {node, shuffle_index};
}
const auto &info = it->second;
Array<PrimExpr> indices = node->indices;
const PrimExpr &last_dim_index = indices[indices.size() - 1];
const RampNode *ramp_index = indices[indices.size() - 1].as<RampNode>();
if (node->buffer->dtype.is_scalable_vector() ||
last_dim_index.dtype().is_scalable_vector()) {
// Scalable types are not currently supported in storage_rewrite. Scalable
// buffer accesses are not currently checked and therefore are not
// rewritten.
return {node, shuffle_index};
}
if (ramp_index && is_one(ramp_index->stride) &&
ramp_index->lanes->IsInstance<IntImmNode>()) {
int lanes = static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
PrimExpr new_index =
ramp_index->base / make_const(ramp_index->base.dtype(), lanes);
if (lanes != info.factor()) {
ICHECK(info.factor() && lanes % info.factor() == 0);
int new_lanes = lanes / info.factor();
new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes,
ramp_index->span);
}
indices.Set(indices.size() - 1, new_index);
} else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) {
arith::ModularSet me = analyzer_.modular_set(last_dim_index);
ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0);
PrimExpr new_index =
last_dim_index / make_const(last_dim_index.dtype(), info.factor());
shuffle_index = me->base % info.factor();
indices.Set(indices.size() - 1, new_index);
}
auto writer = node.CopyOnWrite();
writer->buffer = RemapBuffer(node->buffer);
writer->indices = indices;
return {node, shuffle_index};
}
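// Illustrative example (hypothetical rewrite with factor() == 4): a vectorized
// access buf[Ramp(o, 1, 4)] becomes vec_buf[o / 4], while a scalar read buf[k]
// with k known to equal 4 * m + 2 becomes vec_buf[k / 4] with
// shuffle_index = 2, which the caller turns into Shuffle::ExtractElement.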
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto [modified, shuffle_index] = VisitBufferAccess(node);
// Not needed for BufferStoreNode, so we can't just call
// LegalizeDtype() in VisitBufferAccess.
if (node.same_as(modified)) {
return std::move(node);
} else {
auto writer = modified.CopyOnWrite();
// writer->LegalizeDType();
LegalizeBufferLoadDType(writer);
if (shuffle_index >= 0) {
return Shuffle::ExtractElement(std::move(modified), shuffle_index);
}
return std::move(modified);
}
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto [modified, shuffle_index] = VisitBufferAccess(std::move(node));
ICHECK(shuffle_index < 0);
return std::move(modified);
}
Stmt VisitStmt_(const LetStmtNode *op) final {
auto it = rewrite_map_.find(op->var.get());
PrimExpr value = this->VisitExpr(op->value);
Stmt body = this->VisitStmt(op->body);
Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
if (var.same_as(op->var) && value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Stmt>(op);
}
return LetStmt(var, value, body);
}
Buffer RemapBuffer(Buffer buf) {
auto cache_key = buf.get();
auto cache_it = buffer_map_.find(cache_key);
if (cache_it != buffer_map_.end()) {
return cache_it->second;
}
auto info_it = rewrite_map_.find(buf->data.get());
if (info_it != rewrite_map_.end()) {
auto &info = info_it->second;
Array<PrimExpr> shape = buf->shape;
PrimExpr last_dim = shape[shape.size() - 1];
shape.Set(shape.size() - 1,
last_dim / make_const(last_dim.dtype(), info.factor()));
auto writer = buf.CopyOnWrite();
writer->data = info.new_buffer_var;
writer->dtype = info.new_element_dtype;
writer->shape = shape;
}
buffer_map_[cache_key] = buf;
return buf;
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
if (!rewrite_indices_) {
return expr;
}
const VarNode *buffer_var = op->args[1].as<VarNode>();
auto it = rewrite_map_.find(buffer_var);
if (it == rewrite_map_.end()) {
return expr;
}
const auto &info = it->second;
PrimExpr index = op->args[2];
PrimExpr extent = op->args[3];
PrimExpr flag = op->args[4];
PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype);
int factor = info.factor();
extent = extent / make_const(extent.dtype(), factor);
index = index / make_const(index.dtype(), factor);
Array<PrimExpr> acc_args{e_dtype, info.new_buffer_var, index, extent,
flag};
return Call(info.new_element_dtype, builtin::tvm_access_ptr(), acc_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
Stmt VisitStmt_(const AllocateNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
auto it = rewrite_map_.find(op->buffer_var.get());
if (it == rewrite_map_.end()) {
return stmt;
}
const auto &info = it->second;
Var new_buffer_var = info.new_buffer_var;
Array<PrimExpr> extents = op->extents;
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
}
Stmt VisitStmt_(const AllocateConstNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateConstNode>();
auto it = rewrite_map_.find(op->buffer_var.get());
if (it == rewrite_map_.end()) {
return stmt;
}
const auto &info = it->second;
Var new_buffer_var = info.new_buffer_var;
int factor = info.new_element_dtype.lanes() / op->dtype.lanes();
Array<PrimExpr> extents = op->extents;
extents.Set(extents.size() - 1, extents[extents.size() - 1] /
make_const(extents[0].dtype(), factor));
return AllocateConst(new_buffer_var, info.new_element_dtype, extents,
op->data, op->body);
}
/* Update the parameters and all remaining variable references
*
* Should be called after calling operator() on the body of the
* function.
*
* @param func A pointer to the PrimFunc being modified.
*/
void Finalize(PrimFunc *func_ptr) {
ICHECK(func_ptr) << "Finalize expects a non-null pointer";
auto &func = *func_ptr;
auto *n = func.CopyOnWrite();
// Remap any remaining references to the old buffer variables
Map<Var, Var> var_remap;
for (const auto &pair : rewrite_map_) {
const auto &info = pair.second;
var_remap.Set(info.old_buffer_var, info.new_buffer_var);
}
n->body = Substitute(n->body, var_remap);
// Remap the argument list to use the new buffer variables.
Array<Var> new_params;
for (const auto &old_param : n->params) {
auto it = rewrite_map_.find(old_param.get());
if (it == rewrite_map_.end()) {
new_params.push_back(old_param);
} else {
const auto &info = it->second;
new_params.push_back(info.new_buffer_var);
}
}
n->params = new_params;
// Remap the Buffer objects in PrimFunc::buffer_map so that the
// buffers use the new buffer variables
Map<Var, Buffer> new_buffer_map;
for (const auto &pair : n->buffer_map) {
Var key = pair.first;
Buffer old_buffer = pair.second;
Var old_var = old_buffer->data;
Buffer new_buffer = RemapBuffer(old_buffer);
new_buffer_map.Set(key, new_buffer);
}
n->buffer_map = new_buffer_map;
}
private:
struct RewriteInfo {
Var old_buffer_var;
Var new_buffer_var;
DataType old_element_dtype;
DataType new_element_dtype;
int factor() const {
int old_lanes = old_element_dtype.lanes();
int new_lanes = new_element_dtype.lanes();
ICHECK_EQ(new_lanes % old_lanes, 0);
return new_lanes / old_lanes;
}
};
bool rewrite_indices_{true};
std::unordered_map<const VarNode *, RewriteInfo> rewrite_map_;
std::unordered_map<const BufferNode *, Buffer> buffer_map_;
arith::Analyzer analyzer_;
};
// Rewrite allocates, pointer parameters, and buffer map into vectorized
// versions if each access into a buffer is the same vector type.
PrimFunc PointerValueTypeRewrite(
PrimFunc f, bool allow_untyped_pointers = false, bool rewrite_params = true,
bool rewrite_buffer_map = true, bool rewrite_allocate_node = true,
bool rewrite_indices = true, bool rewrite_let_node = true,
bool rewrite_allocate_const_node = true,
bool rewrite_scalar_read_to_vector_shuffle = true) {
VectorTypeAccessChecker checker(f->params, f->buffer_map,
allow_untyped_pointers,
rewrite_scalar_read_to_vector_shuffle);
checker(f->body);
VectorTypeRewriter rewriter(
checker.info_map_, rewrite_params, rewrite_buffer_map,
rewrite_allocate_node, rewrite_indices, rewrite_let_node,
rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle);
PrimFuncNode *n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
rewriter.Finalize(&f);
return f;
}
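// Illustrative before/after for the rewrite above (a sketch, not generated
// output): when every access to a buffer uses the same vector type,
//   allocate A : Pointer(float32), extents = [256]
//     ... A[ramp(i * 4, 1, 4)] ...
// is rewritten to
//   allocate A : Pointer(float32x4), extents = [64]
//     ... A[i] ...
// i.e. the element dtype is widened to the vector dtype and the last extent
// is divided by the lane factor from RewriteInfo::factor().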
using namespace tir::transform;
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem =
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
AllocateCollector collector;
collector(f->body);
bool has_dynamic = collector.dyn_shmem_allocs_.size() > 1;
if (has_dynamic || merge_static_smem) {
      // When the function contains more than one dynamic shared memory
      // allocation, reuse is not enabled here: dynamic shared memory does
      // not need to keep a readable layout at this stage, and it benefits
      // from the more aggressive allocation strategy of the dedicated
      // `MergeSharedMemoryAllocations` pass.
      // Likewise, when `merge_static_smem` is true, static shared memory is
      // reused and merged in that dedicated pass, so reuse is disabled here too.
enable_reuse = false;
}
Optional<Target> target = f->GetAttr<Target>("target");
if (target.defined() && (target.value()->kind->name == "vulkan" ||
target.value()->kind->name == "webgpu")) {
      // Require exact dtype matching for shared-memory reuse on Vulkan and WebGPU
reuse_require_exact_matched_dtype = true;
}
auto *n = f.CopyOnWrite();
n->body =
StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
reuse_require_exact_matched_dtype);
// Parameters may not be rewritten, but internal allocations may.
    // Vectorization of AllocateConst is currently disabled: types that
    // include padding (e.g. int8x3 padded out to 32 bits) have indexing
    // issues, and supporting them would require either rewriting
    // AllocateConst::data or teaching the code generators to handle
    // vectorized constants.
return PointerValueTypeRewrite(std::move(f), true, false, false, false,
true, true, false, false);
};
return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
});
Pass PointerValueTypeRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
return tl::PointerValueTypeRewrite(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
PointerValueTypeRewrite);
});
} // namespace transform
} // namespace tl
} // namespace tvm
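The registrations above replace `TVM_REGISTER_GLOBAL` with the reflection-based `TVM_FFI_STATIC_INIT_BLOCK` registry used by upstream TVM. A minimal sketch of how one of these passes might then be driven from Python, assuming the tilelang shared library is already loaded so the global function is visible; the IRModule `mod` and the use of `tvm.get_global_func` are illustrative and not part of this diff:

```python
import tvm

# Look up the packed function registered by the FFI static-init block above.
storage_rewrite = tvm.get_global_func("tl.transform.StorageRewrite")

# Calling it returns a tvm.transform.Pass. `tir.merge_static_smem` is the
# PassContext config key the pass reads to decide whether shared-memory
# reuse is deferred to MergeSharedMemoryAllocations.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    mod = storage_rewrite()(mod)  # `mod` is an existing TIR IRModule (assumed)
```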
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -269,7 +270,7 @@ private:
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
num_partial_threads_ = std::nullopt;
} else {
TileLangStorageAccessVisitor::VisitStmt_(op);
}
......@@ -371,8 +372,11 @@ Pass TileLangThreadPartialSync(String storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
.set_body_typed(TileLangThreadPartialSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadPartialSync",
TileLangThreadPartialSync);
});
} // namespace transform
} // namespace tl
......
......@@ -20,7 +20,8 @@
/*!
* \file thread_storage_sync.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -367,7 +368,7 @@ private:
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
num_partial_threads_ = NullOpt;
num_partial_threads_ = std::nullopt;
} else {
TileLangStorageAccessVisitor::VisitStmt_(op);
}
......@@ -786,7 +787,10 @@ tvm::transform::Pass ThreadSync(String storage_scope) {
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadSync").set_body_typed(ThreadSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
});
} // namespace transform
} // namespace tl
......
......@@ -22,7 +22,8 @@
*/
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
......@@ -631,7 +632,7 @@ public:
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt;
Optional<Stmt> else_case = std::nullopt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
......@@ -688,10 +689,6 @@ public:
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
// analyzer
......@@ -787,6 +784,10 @@ private:
}
};
inline bool TargetHasSVE() {
return Target::Current()->GetFeature<Bool>("has_sve").value_or(false);
}
class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode *op) final {
......@@ -796,7 +797,7 @@ public:
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr =
CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && arith::TargetHasSVE())
ICHECK(is_scalable_expr && TargetHasSVE())
<< "Failed to vectorize loop with extent " << op->extent
<< " for target " << Target::Current();
}
......@@ -837,7 +838,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
}
TVM_REGISTER_GLOBAL("tl.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
});
} // namespace tl
} // namespace tvm
......@@ -5,6 +5,7 @@
#include "arith/ir_visitor_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -447,7 +448,7 @@ private:
order_anno.push_back(Integer(op_info.order));
stage_anno.push_back(Integer(op_info.stage));
}
Map<String, ObjectRef> for_annotations = op->annotations;
Map<String, Any> for_annotations = op->annotations;
for_annotations.erase("tl_pipeline_group");
for_annotations.Set("software_pipeline_order", order_anno);
for_annotations.Set("software_pipeline_stage", stage_anno);
......@@ -636,9 +637,9 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
int num_stages = 1;
auto num_stages_anno = op->annotations.Get("num_stages");
if (num_stages_anno.defined()) {
ICHECK(num_stages_anno.as<IntImmNode>());
num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
if (num_stages_anno) {
ICHECK(num_stages_anno->as<IntImmNode>());
num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
}
loop_stack_.emplace_back(op->loop_var, op->extent);
......@@ -648,16 +649,16 @@ private:
Array<Integer> stage_info_array;
auto group_anno = op->annotations.Get("tl_pipeline_group");
if (group_anno.defined()) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
if (group_anno) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno.value());
}
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (order_anno.defined()) {
order_info_array = Downcast<Array<Integer>>(order_anno);
if (order_anno) {
order_info_array = Downcast<Array<Integer>>(order_anno.value());
}
auto stage_anno = op->annotations.Get("tl_pipeline_stage");
if (stage_anno.defined()) {
stage_info_array = Downcast<Array<Integer>>(stage_anno);
if (stage_anno) {
stage_info_array = Downcast<Array<Integer>>(stage_anno.value());
}
PipelineInfo pipeline_info(group_info_array, order_info_array,
......@@ -686,8 +687,8 @@ private:
auto result = FilterByRole(op);
Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno.defined() &&
group_info_array.size() > 0 && !is_emitting_producer_) {
if (result.as<ForNode>() && group_anno && group_info_array.size() > 0 &&
!is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result);
grouped_for_node = group_op_rewriter(for_node);
......@@ -707,7 +708,7 @@ private:
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
}
if (is_emitting_producer_ || !group_anno.defined() ||
if (is_emitting_producer_ || !group_anno ||
group_info_array.size() == 0) {
loop_stack_.pop_back();
return for_node;
......@@ -1230,8 +1231,10 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized")
.set_body_typed(WarpSpecialized);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
});
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file warp_specialized_pipeline.cc
* \brief Warp specialized Pipeline for cuda GPU (sm90+)
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
......@@ -131,7 +113,7 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
auto order_anno = op->annotations.Get("tl_pipeline_order");
if (!order_anno.defined()) {
if (!order_anno) {
return StmtExprMutator::VisitStmt_(op);
}
......@@ -281,8 +263,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.RewriteWgmmaSync")
.set_body_typed(RewriteWgmmaSync);
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
});
} // namespace tl
} // namespace tvm
......@@ -4,6 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
import torch
tilelang.disable_cache()
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
num_stages = 0
......
......@@ -40,8 +40,8 @@ def tl_matmul(
assert in_dtype in [
"float16",
"bfloat16",
"e4m3_float8",
"e5m2_float8",
"float8_e4m3",
"float8_e5m2",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -52,7 +52,7 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
......@@ -220,4 +220,5 @@ def test_assert_tl_matmul_bfloat16():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_assert_tl_matmul_bfloat16()
# ruff: noqa
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
import torch
from typing import Optional, Union
from einops import rearrange, repeat
tilelang.testing.set_random_seed(42)
def naive_nsa_ref(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
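    """Unfused reference implementation of native sparse attention (NSA).

    Shapes below assume ``head_first=False`` (batch, time, heads, dim):
        q:             [B, T, HQ, D]
        k, v:          [B, T, H,  D]  with HQ = H * G (grouped query heads)
        g_slc, g_swa:  [B, T, HQ]     gates applied to the selected-block and
                                      sliding-window outputs respectively
        block_indices: [B, T, H, S]   indices of the selected key/value blocks
        block_counts:  [B, T, H] or int, number of valid selected blocks
    Selected blocks that are out of range, lie after the query position, or
    fall beyond ``block_counts`` are masked to -inf before the softmax;
    ``window_size > 0`` adds a sliding-window branch gated by ``g_swa``.
    """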
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat(
[block_indices.new_tensor(range(0, B * T, T)),
block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
if not varlen:
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def native_sparse_attention(batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32):
if scale is None:
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
NK = tilelang.cdiv(dim, block_T)
NV = tilelang.cdiv(dim, block_T)
    assert NK == 1, "The key dimension cannot be larger than 128"
S = selected_blocks
G = groups
BS = block_S
BK = BV = block_T
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
K_shared = T.alloc_shared([BS, BK], dtype)
V_shared = T.alloc_shared([BS, BV], dtype)
O_shared = T.alloc_shared([G, BV], dtype)
acc_s = T.alloc_fragment([G, BS], accum_dtype)
acc_s_cast = T.alloc_fragment([G, BS], dtype)
acc_o = T.alloc_fragment([G, BV], accum_dtype)
scores_max = T.alloc_fragment([G], accum_dtype)
scores_max_prev = T.alloc_fragment([G], accum_dtype)
scores_scale = T.alloc_fragment([G], accum_dtype)
scores_sum = T.alloc_fragment([G], accum_dtype)
logsum = T.alloc_fragment([G], accum_dtype)
i_t, i_v, i_bh = bx, by, bz
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G:(i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Pipelined(NS, num_stages=num_stages):
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=True)
for i in T.Parallel(G):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(G):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
# Rescale
for i, j in T.Parallel(G, BV):
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G:(i_h + 1) * G, i_v * BV:(i_v + 1) * BV])
return native_sparse_attention
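# Note on the scaling above: the kernel folds log2(e) into `scale` so that
# T.exp2 can stand in for exp in the online softmax, via the identity
# exp(x * s) == 2 ** (x * s * log2(e)). The `scores_scale` / `logsum`
# updates then follow the usual flash-attention running-softmax rescaling.
# A tiny standalone check of the identity (illustrative only):
import math
assert abs(math.exp(0.37 * 0.125) - 2.0 ** (0.37 * 0.125 * 1.44269504)) < 1e-9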
def run_native_sparse_attention(batch,
heads,
seq_len,
dim,
is_causal,
scale=None,
block_size=64,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32):
dtype = torch.float16
head_kv = heads // groups
program = native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale, block_size,
groups, selected_blocks, num_stages, threads)
kernel = tilelang.compile(program, out_idx=-1)
Q = torch.randn((batch, seq_len, heads, dim), dtype=dtype).cuda()
K = torch.randn((batch, seq_len, head_kv, dim), dtype=dtype).cuda()
V = torch.randn((batch, seq_len, head_kv, dim), dtype=dtype).cuda()
g_slc = torch.ones((batch, seq_len, heads), dtype=dtype).cuda()
g_swa = torch.ones((batch, seq_len, heads), dtype=dtype).cuda()
block_indices = torch.full((batch, seq_len, head_kv, selected_blocks),
seq_len,
dtype=torch.long,
device='cuda')
for b in range(batch):
for t in range(seq_len):
for h in range(head_kv):
i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, selected_blocks + 1, (batch, seq_len, head_kv), device='cuda')
out = kernel(Q, K, V, block_indices.to(torch.int32))
ref = naive_nsa_ref(
q=Q,
k=K,
v=V,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
scale=scale,
)
torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
def test_tilelang_kernel_deepseek_nsa():
# disable pipeline
run_native_sparse_attention(
batch=2,
heads=64,
seq_len=1,
dim=16,
is_causal=True,
scale=None,
block_size=32,
groups=16,
selected_blocks=16,
num_stages=0,
threads=32)
# enable pipeline
run_native_sparse_attention(
batch=2,
heads=64,
seq_len=1,
dim=16,
is_causal=True,
scale=None,
block_size=32,
groups=16,
selected_blocks=16,
num_stages=2,
threads=32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -97,7 +97,7 @@ def test_fp4_fp16_convert_close():
block_K,
"float16",
)
print(program.script())
kernel = tilelang.compile(program, out_idx=[1])
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
......@@ -642,4 +642,5 @@ def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_fp4_fp16_convert_close()