"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "3274ca3094bf05d4fb9d6afa554a2bd71001b2d8"
Commit c4638d65 authored by Lei Wang, committed by LeiWang1999

[Language] Introduce `T.any_of` and `T.all_of` to reduce a boolean array (#371)



* [Enhancement] Introduce logical operations `any_of` and `all_of` for buffer checks

- Added new logical operations `any_of` and `all_of` to the TileLang language interface, allowing users to check conditions across buffer elements (see the sketch after this list).
- Implemented the corresponding CUDA device helpers (`tl::Any` and `tl::All`) that the new intrinsics lower to.
- Updated `allocate.py` to handle boolean types correctly in shared memory allocations.
- Introduced tests covering the new logical operations to verify correctness.
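
For context, a minimal sketch of the intended surface, condensed from the tests added below (shapes and names are illustrative, not part of this commit):

import tilelang.language as T

@T.prim_func
def demo(
        A: T.Tensor((128, 32), "float16"),
        BlockMask: T.Tensor((2,), "bool"),
        B: T.Tensor((128, 32), "float16"),
):
    with T.Kernel(1, 1, threads=128) as (bx, by):
        A_shared = T.alloc_shared((128, 32), "float16")
        # Gate the copy on the boolean buffer: all_of requires every flag to be set,
        # any_of would require only one.
        if T.all_of(BlockMask):
            T.copy(A[0, 0], A_shared)
            T.copy(A_shared, B[0, 0])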
Co-authored-by: Zhiwen Mo <zhiwen.mo25@ic.ac.uk>

* lint fix

---------
Co-authored-by: Zhiwen Mo <zhiwen.mo25@ic.ac.uk>
parent 9a7a569d
Subproject commit 2e033790663f0a470865b38da6931e9addb09238
Subproject commit fbd82e919c01238a5ac78a4d5c66b3da80161255
......@@ -158,6 +158,7 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
}
#endif
// DP4A
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
......@@ -166,3 +167,25 @@ TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
const int c_int = *((int *)c);
*c = __dp4a(a_int, b_int, c_int);
}
namespace tl {
// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
for (int i = 0; i < size; i++) {
if (a[i]) {
return true;
}
}
return false;
}
// All
template <typename T> TL_DEVICE bool All(T *a, int size) {
for (int i = 0; i < size; i++) {
if (!a[i]) {
return false;
}
}
return true;
}
} // namespace tl
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file flatten_buffer.cc
*/
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Transform multi-dimensional BufferLoad/BufferStore into device-supported
 * dimensions for TIR that does not contain opaque blocks.
*/
class BufferFlattener : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc Flatten(PrimFunc func) {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
// arguments. The flattened buffers used in the updated
// function body alias the argument buffers.
return func;
}
private:
using IRMutatorWithAnalyzer::VisitExpr;
using IRMutatorWithAnalyzer::VisitExpr_;
using IRMutatorWithAnalyzer::VisitStmt;
using IRMutatorWithAnalyzer::VisitStmt_;
explicit BufferFlattener(arith::Analyzer *ana) : IRMutatorWithAnalyzer(ana) {}
Stmt VisitStmt_(const BlockNode *op) final {
ICHECK_EQ(op->match_buffers.size(), 0)
<< "Unexpected MatchBufferRegion found during "
"tir.transform.FlattenBuffer. "
<< "All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer.";
Block block = GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply(
[this](Buffer buf) { return GetFlattenedBuffer(buf); });
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
}
Array<BufferRegion> reads = op->reads;
reads.MutateByApply(
[this](BufferRegion region) { return MutateBufferRegion(region); });
if (!reads.same_as(op->reads)) {
block.CopyOnWrite()->reads = reads;
}
Array<BufferRegion> writes = op->writes;
writes.MutateByApply(
[this](BufferRegion region) { return MutateBufferRegion(region); });
if (!writes.same_as(op->writes)) {
block.CopyOnWrite()->writes = writes;
}
return StmtExprMutator::VisitStmt_(block.get());
}
Stmt VisitStmt_(const AllocateNode *op) final {
// Determine the flattened extents first, before stripping off the
// DeclBuffer.
auto new_extents = [&]() -> Array<PrimExpr> {
if (op->extents.size() == 1) {
// No flattening required for buffers that are already flat
return op->extents;
}
if (auto *decl_buffer = op->body.as<DeclBufferNode>()) {
// N-d buffer, use the DeclBuffer inside to determine how it
// should be flattened.
auto &buffer = decl_buffer->buffer;
bool matching_buffer = [&]() {
if (!decl_buffer->buffer->data.same_as(op->buffer_var)) {
return false;
}
if (op->dtype != buffer->dtype) {
return false;
}
if (op->extents.size() != buffer->shape.size()) {
return false;
}
ExprDeepEqual expr_equal;
for (size_t i = 0; i < op->extents.size(); i++) {
if (!expr_equal(op->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}();
if (matching_buffer) {
Buffer flattened = GetFlattenedBuffer(buffer);
return flattened->shape;
} else {
ICHECK(decl_buffer->buffer->axis_separators.empty())
<< "DeclBuffer node doesn't match Allocate extents, but also "
"shouldn't be "
"flattened to 1-d physical memory";
}
}
// Fallback, this is an allocation without a matching DeclBuffer
PrimExpr flat_extent = 1;
for (const auto &dim : op->extents) {
flat_extent *= dim;
}
return {flat_extent};
}();
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (alloc->dtype == DataType::Bool()) {
alloc.CopyOnWrite()->dtype = DataType::Int(8);
}
if (!new_extents.same_as(alloc->extents)) {
alloc.CopyOnWrite()->extents = new_extents;
}
return std::move(alloc);
}
Stmt VisitStmt_(const DeclBufferNode *op) final {
// TODO(rfc-70): Update the DeclBuffer node instead of
// stripping it out. Stripping it out in the current
// implementation as not all lowering passes support
// DeclBuffer.
return VisitStmt(op->body);
}
Buffer GetFlattenedBuffer(Buffer buf) {
auto it = buffer_remap_.find(buf);
if (it != buffer_remap_.end()) {
return it->second;
}
auto flattened = buf.GetFlattenedBuffer();
auto writer = flattened.CopyOnWrite();
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (flattened->dtype == DataType::Bool()) {
writer->dtype = DataType::Int(8);
}
// canonicalize shape
for (size_t i = 0; i < flattened->shape.size(); ++i) {
writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i]));
}
buffer_remap_[buf] = flattened;
return flattened;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
bool store_returns_bool = (op->value.dtype() == DataType::Bool());
store = VisitBufferAccess(store);
// Handle casts from the value's dtype to the dtype of the
// backing array.
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (store_returns_bool) {
ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
<< "Expected int8 backing array for boolean tensor";
auto writer = store.CopyOnWrite();
writer->value = tvm::cast(DataType::Int(8), store->value);
return std::move(store);
}
return std::move(store);
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(load);
}
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
const Array<PrimExpr> &indices) {
auto flattened_indices = buffer->ElemOffset(indices);
return this->IterMapSimplifyWithContext(flattened_indices, false);
}
template <typename Node> Node VisitBufferAccess(Node node) {
ICHECK(node->buffer.defined());
auto flattened_indices =
GetSimplifiedElemOffset(node->buffer, node->indices);
Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);
auto writer = node.CopyOnWrite();
writer->buffer = flattened_buffer;
writer->indices = flattened_indices;
return node;
}
BufferRegion MutateBufferRegion(BufferRegion region) {
Buffer orig_buf = region->buffer;
Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
if (flattened_buf.same_as(orig_buf)) {
return region;
}
Array<PrimExpr> min_values;
Array<PrimExpr> max_values;
for (const auto &range : region->region) {
min_values.push_back(range->min);
max_values.push_back(range->min + range->extent - 1);
}
Array<PrimExpr> flattened_min =
GetSimplifiedElemOffset(orig_buf, min_values);
Array<PrimExpr> flattened_max =
GetSimplifiedElemOffset(orig_buf, max_values);
Array<Range> flattened_ranges;
ICHECK_EQ(flattened_min.size(), flattened_max.size());
for (size_t i = 0; i < flattened_min.size(); i++) {
flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
}
return BufferRegion(flattened_buf, flattened_ranges);
}
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_;
/*! \brief The updated external buffer map. */
Map<Var, Buffer> updated_extern_buffer_map_;
};
PrimFunc FlattenBufferRewriter(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
return BufferFlattener::Flatten(f);
} else {
return f;
}
}
using namespace tir::transform;
tvm::transform::Pass FlattenBuffer() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return FlattenBufferRewriter(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
}
TVM_REGISTER_GLOBAL("tl.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
} // namespace tl
} // namespace tvm
......@@ -121,6 +121,16 @@ struct GlobalMemChecker : public StmtExprVisitor {
PrimExpr index = indices[i];
PrimExpr shape_dim = buffer->shape[i];
bool has_variable = false;
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
has_variable = true;
}
});
if (!has_variable) {
continue;
}
// We want to check if index < shape_dim can be proven.
// If analyzer->CanProve(index < shape_dim) returns false,
// it means we cannot prove the access is within bounds.
......@@ -160,12 +170,18 @@ private:
if (IsGlobalBuffer(store->buffer)) {
Stmt store_with_conditions = store;
for (auto cond : conditions) {
LOG(INFO) << "condition: " << cond;
LOG(INFO) << "store: " << store;
store_with_conditions = IfThenElse(cond, store_with_conditions);
}
return store_with_conditions;
} else if (isSharedBuffer(store->buffer)) {
PrimExpr value = store->value;
LOG(INFO) << "value: " << value;
LOG(INFO) << "conditions: " << conditions;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, make_zero(value->dtype));
}
store.CopyOnWrite()->value = value;
......
......@@ -91,17 +91,18 @@ private:
pinfo.original_order = idx;
// a copy stage should only have one read and one write
bool write_to_shared = false;
bool write_to_shared_or_local = false;
bool read_from_global = false;
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global")
read_from_global = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "shared" ||
region->buffer.scope() == "shared.dyn")
write_to_shared = true;
region->buffer.scope() == "shared.dyn" ||
region->buffer.scope() == "local")
write_to_shared_or_local = true;
pinfo.copy_stage = write_to_shared && read_from_global;
pinfo.copy_stage = write_to_shared_or_local && read_from_global;
return std::move(pinfo);
}
......
......@@ -39,7 +39,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
Var buf = op->buffer->data;
StorageScope scope = GetScope(buf);
if (Enabled(buf.get(), scope)) {
ICHECK(allow_append_) << op << " " << scope.to_string();
ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
AccessEntry e;
e.threads = env_threads();
e.buffer = buf;
......@@ -203,7 +203,12 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
if (!is_thread_invariant) {
++condition_counter_;
}
allow_append_ = true;
this->VisitExpr(op->condition);
curr_stmt_.access.clear();
allow_append_ = false;
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->then_case);
StmtEntry s;
......@@ -244,25 +249,28 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
if (op->op.same_as(builtin::address_of())) {
ICHECK_EQ(op->args.size(), 1U);
const BufferLoadNode *load = op->args[0].as<BufferLoadNode>();
Buffer buffer = load->buffer;
DataType dtype = buffer->dtype;
const VarNode *buffer_var = buffer->data.as<VarNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_);
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
e.buffer = Downcast<Var>(buffer->data);
for (const auto &index : load->indices) {
e.touched.push_back(arith::IntSet::Vector(index));
if (auto load = op->args[0].as<BufferLoadNode>()) {
Buffer buffer = load->buffer;
DataType dtype = buffer->dtype;
const VarNode *buffer_var = buffer->data.as<VarNode>();
StorageScope scope = GetScope(GetRef<Var>(buffer_var));
if (Enabled(buffer_var, scope)) {
ICHECK(allow_append_);
AccessEntry e;
e.threads = env_threads();
e.dtype = dtype;
e.buffer = Downcast<Var>(buffer->data);
for (const auto &index : load->indices) {
e.touched.push_back(arith::IntSet::Vector(index));
}
e.type = kRead;
e.scope = scope;
curr_stmt_.access.emplace_back(e);
}
e.type = kRead;
e.scope = scope;
curr_stmt_.access.emplace_back(e);
StmtExprVisitor::VisitExpr_(load);
} else {
StmtExprVisitor::VisitExpr_(op);
}
StmtExprVisitor::VisitExpr_(load);
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
......
......@@ -444,17 +444,20 @@ public:
ICHECK_EQ(op->args.size(), 1U)
<< "address_of should only have one argument (Buffer)";
BufferLoad load = Downcast<BufferLoad>(op->args[0]);
Var buffer_var(Downcast<Var>(load->buffer->data));
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].read_count;
}
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].write_count;
if (auto load = op->args[0].as<BufferLoadNode>()) {
Var buffer_var(Downcast<Var>(load->buffer->data));
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].read_count;
}
if (sync_scope_.rank == StorageRank::kGlobal &&
GetScope(buffer_var).rank == StorageRank::kGlobal) {
++rw_stats_[buffer_var].write_count;
}
return expr;
} else {
return StmtExprMutator::VisitExpr_(op);
}
return expr;
} else {
return StmtExprMutator::VisitExpr_(op);
}
......
import tilelang
import tilelang.testing
import tilelang.language as T
import torch
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
M, K = A.shape
N = B.shape[1]
ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
for i in range(M // block_M):
for j in range(N // block_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if torch.all(BlockMask[i, j, k]):
accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32)
ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (
accu.to(torch.float16))
return ref_c
def blocksparse_matmul_global(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if T.all_of(BlockMask[by, bx, k, :]):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def blocksparse_matmul_shared(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
block_mask_shared = T.alloc_shared(condition_dim, "bool")
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
for i in T.serial(condition_dim):
block_mask_shared[i] = BlockMask[by, bx, k, i]
# or T.all_of(block_mask_shared[0:condition_dim])
# or T.all_of(block_mask_shared[:])
if T.all_of(block_mask_shared):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def blocksparse_matmul_local(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
block_mask_local = T.alloc_local(condition_dim, "bool")
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
for i in T.serial(condition_dim):
block_mask_local[i] = BlockMask[by, bx, k, i]
# or T.all_of(block_mask_local[0:condition_dim])
# or T.all_of(block_mask_local[:])
if T.all_of(block_mask_local):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_global(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_shared(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_local(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def test_block_sparse_matmul_global():
run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
def test_block_sparse_matmul_shared():
run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
def test_block_sparse_matmul_local():
run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
if __name__ == "__main__":
tilelang.testing.main()
import tilelang
import tilelang.testing
import tilelang.language as T
import torch
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
M, K = A.shape
N = B.shape[1]
ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
for i in range(M // block_M):
for j in range(N // block_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if torch.any(BlockMask[i, j, k]):
accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32)
ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (
accu.to(torch.float16))
return ref_c
def blocksparse_matmul_global(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if T.any_of(BlockMask[by, bx, k, :]):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def blocksparse_matmul_shared(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
block_mask_shared = T.alloc_shared(condition_dim, "bool")
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
for i in T.serial(condition_dim):
block_mask_shared[i] = BlockMask[by, bx, k, i]
# or T.any_of(block_mask_shared[0:condition_dim])
# or T.any_of(block_mask_shared[:])
if T.any_of(block_mask_shared):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def blocksparse_matmul_local(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
block_mask_local = T.alloc_local(condition_dim, "bool")
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
for i in T.serial(condition_dim):
block_mask_local[i] = BlockMask[by, bx, k, i]
# or T.any_of(block_mask_local[0:condition_dim])
# or T.any_of(block_mask_local[:])
if T.any_of(block_mask_local):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
def run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_global(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_shared(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2):
block_M = 128
block_N = 128
block_K = 32
num_stages = 2
thread_num = 128
enable_rasteration = True
# Initialize input matrices A and B on the GPU with half precision
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
func = blocksparse_matmul_local(
M,
N,
K,
condition_dim,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
block_mask = block_mask.view(mask_shape + (1,)).repeat(1, 1, 1, condition_dim)
# Force the first element along the last (condition) dimension to False
block_mask[:, :, :, 0] = False
# Run the compiled kernel (either tuned or default) with the inputs
c = kernel(a, b, block_mask)
# Compute the reference result using the naive PyTorch implementation
ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
def test_block_sparse_matmul_global():
run_block_sparse_matmul_global(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
def test_block_sparse_matmul_shared():
run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
def test_block_sparse_matmul_local():
run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, condition_dim=2)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -52,7 +52,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# TODO(lei): may need a pass to fuse the if-then-else in the
# pipeline loop when we meet dynamic branch.
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop()(mod)
......
......@@ -58,6 +58,7 @@ from .customize import (
reshape, # noqa: F401
view, # noqa: F401
)
from .logical import any_of, all_of # noqa: F401
from .builtin import * # noqa: F401
from .memscope import * # noqa: F401
......
......@@ -28,6 +28,10 @@ def alloc_shared(shape, dtype, scope="shared.dyn"):
Returns:
T.Buffer: A TVM buffer object allocated in shared memory
"""
if dtype == "bool":
# NOTE(lei): This is a workaround for bool types, because TileLang's
# shared-memory merge pass cannot currently merge bool buffers.
scope = "shared"
return T.alloc_buffer(shape, dtype, scope=scope)
......
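A small hedged sketch of the alloc_shared behaviour described in the comment above; the scope switch happens inside alloc_shared, so call sites stay unchanged (shapes and names are illustrative):

import tilelang.language as T

@T.prim_func
def demo(Mask: T.Tensor((4,), "bool")):
    with T.Kernel(1, 1, threads=32) as (bx, by):
        # Boolean buffers fall back to the static "shared" scope,
        # while other dtypes keep the default "shared.dyn" scope.
        mask_shared = T.alloc_shared(4, "bool")
        data_shared = T.alloc_shared((128, 32), "float16")
        for i in T.serial(4):
            mask_shared[i] = Mask[i]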
"""The language interface for tl programs."""
from tilelang import language as T
import tvm
from tvm.tir import Buffer, BufferRegion
from tvm.ir import Range
from tvm.ir import register_op_attr, register_intrin_lowering
from tvm import tir
from typing import Union
from tilelang.utils.language import get_buffer_elems
# TODO: move this part into src to reduce runtime overhead
def any_of_op(op):
args = op.args
assert len(args) == 2
buffer_address, elems = args
return T.call_extern("bool", "tl::Any", buffer_address, elems)
register_op_attr("tl.any_of", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
register_op_attr("tl.any_of", "TScriptPrinterName", "any_of")
register_intrin_lowering("tl.any_of", target="cuda", f=any_of_op)
def any_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if any element in the buffer is true.
Args:
buffer: Either a TVM buffer or buffer region to be checked
Returns:
A TVM intrinsic call that performs the any operation
"""
return_type: str = "bool"
if isinstance(buffer, Buffer):
elems = get_buffer_elems(buffer)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer), elems)
elif isinstance(buffer, BufferRegion):
buffer, region = buffer.buffer, buffer.region
new_region = []
extent = 1
for i, r in enumerate(region):
extent = r.extent
if extent == 1:
new_region.append(r)
else:
# check the idx is the last dimension
if i != len(region) - 1:
raise ValueError(
"Only support the last dimension to be for T.any currently, please contact us if you need this feature"
)
new_region.append(Range(r.min, 1))
buffer = BufferRegion(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.any_of"), T.address_of(buffer), extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
def all_of_op(op):
args = op.args
assert len(args) == 2
buffer_address, elems = args
return T.call_extern("bool", "tl::All", buffer_address, elems)
register_op_attr("tl.all_of", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
register_op_attr("tl.all_of", "TScriptPrinterName", "all_of")
register_intrin_lowering("tl.all_of", target="cuda", f=all_of_op)
def all_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if all elements in the buffer are true.
Args:
buffer: Either a TVM buffer or buffer region to be checked
Returns:
A TVM intrinsic call that performs the all operation
"""
return_type: str = "bool"
if isinstance(buffer, Buffer):
elems = get_buffer_elems(buffer)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer), elems)
elif isinstance(buffer, BufferRegion):
buffer, region = buffer.buffer, buffer.region
new_region = []
extent = 1
for i, r in enumerate(region):
extent = r.extent
if extent == 1:
new_region.append(r)
else:
# check the idx is the last dimension
if i != len(region) - 1:
raise ValueError(
"Only support the last dimension to be for T.any currently, please contact us if you need this feature"
)
new_region.append(Range(r.min, 1))
buffer = BufferRegion(buffer, new_region)
return T.call_intrin(return_type, tir.op.Op.get("tl.all_of"), T.address_of(buffer), extent)
else:
raise ValueError(f"Invalid buffer type: {type(buffer)}")
......@@ -295,3 +295,14 @@ def ConfigIndexBitwidth():
----
"""
return _ffi_api.ConfigIndexBitwidth() # type: ignore
def FlattenBuffer():
"""FlattenBuffer
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FlattenBuffer() # type: ignore
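A hedged sketch of how this wrapper is meant to be used, mirroring the OptimizeForTarget change earlier in this diff (the surrounding passes are shown only for context):

import tilelang
from tvm import tir

def lower_with_tl_flatten(mod):
    """Apply TileLang's FlattenBuffer within a partial lowering pipeline."""
    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tilelang.transform.FlattenBuffer()(mod)  # bool allocations become int8-backed
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
    return mod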
......@@ -77,6 +77,13 @@ def is_fragment(buffer: Buffer) -> bool:
return buffer.scope().startswith("local.fragment")
def get_buffer_elems(buffer: Buffer) -> int:
"""
Get the number of elements in the buffer.
"""
return reduce(lambda x, y: x * y, buffer.shape)
def array_reduce(array: List[int]) -> int:
"""
Reduce an array of integers to a single integer.
......
......@@ -46,7 +46,7 @@ def adapt_torch2tvm(arg):
return arg
def get_tensor_supply(supply_type: TensorSupplyType):
def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
from tilelang.engine.param import KernelParam
......