Unverified commit 8eab7755 authored by Lei Wang, committed by GitHub

[Reducer] Introduce `alloc_reducer` to separate inter and intra warp reduction (#757)



* [Enhancement] Introduce finalize_reducer operator and layout reducer support

- Added the `FinalizeReducer` operator to handle reduction finalization in the TileLang framework, enabling efficient cross-thread reduction (a generic sketch of the two-stage warp reduction pattern follows this list).
- Implemented layout inference for `local.reducer` buffers, so their layout mappings are derived automatically and buffer management stays simpler.
- Updated `setup.py` to log build directory paths, improving build process visibility.
- Enhanced atomic operations with new functions for atomic max, min, load, and store, giving finer control over atomicity and memory ordering.
- Refactored parallel loop handling to incorporate reducer information, ensuring proper management of reduction operations in parallel contexts.
- Cleaned up test cases by removing unnecessary cache disabling and optimizing test parameters for better performance.
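
For context, a minimal generic CUDA sketch of the two-stage pattern this separation targets: partials are first reduced inside each warp with register shuffles, and only then combined across warps through shared memory. The helper below is illustrative only (it assumes blockDim.x is a multiple of 32 and at most 1024) and is not a TileLang API.

    __device__ float block_reduce_sum(float val) {
      // Intra-warp: shuffle reduction entirely in registers.
      for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
      // Inter-warp: one partial per warp goes through shared memory.
      __shared__ float warp_partials[32];
      int lane = threadIdx.x % 32, warp = threadIdx.x / 32;
      if (lane == 0) warp_partials[warp] = val;
      __syncthreads();
      if (warp == 0) {
        int num_warps = (blockDim.x + 31) / 32;
        val = (lane < num_warps) ? warp_partials[lane] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1)
          val += __shfl_down_sync(0xffffffff, val, offset);
      }
      return val;  // final sum is valid in warp 0, lane 0
    }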

* Refactor code formatting and improve readability in multiple files

- Cleaned up whitespace in `setup.py` to enhance logging clarity.
- Reformatted `AtomicMax` and `AtomicMin` functions in `common.h` for better alignment and readability.
- Adjusted `debug_print_var` function in `debug.h` to improve code structure and maintainability.
- Enhanced readability of the `atomic_add` function in `customize.py` by breaking long lines for better clarity.

* Remove debug print statements from `copy.cc` and `inject_tma_barrier.cc` to enhance code clarity and maintainability.

* [Enhancement] Disable reuse of small arrays in shared memory allocation

- Added logic to prevent the reuse of small arrays (<= 32 bits) in `merge_shared_memory_allocations.cc`, ensuring they are lowered to registers in LLVM for improved performance and memory management.
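
A minimal sketch of what such a guard can look like inside a reuse planner, using hypothetical `AllocInfo`/`EligibleForReuse` names (the actual change is the `merge_shared_memory_allocations.cc` hunk shown further down):

    #include <cstdint>

    // Illustrative only: skip reuse for tiny allocations so the backend can
    // keep them in registers instead of shared memory.
    struct AllocInfo {
      uint64_t const_nbits;   // constant allocation size in bits
      bool special_scope;     // e.g. tensor-core or barrier storage
    };

    bool EligibleForReuse(const AllocInfo &a) {
      if (!a.special_scope && a.const_nbits > 0 && a.const_nbits <= 32)
        return false;  // lowered to registers in LLVM, as in the pass below
      return true;
    }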

* Refactor `setup.py` to remove duplicate logging statements and enhance clarity. Update `finalize_reducer` function documentation in `reduce.py` to include detailed parameter and return descriptions, improving code readability and maintainability.

* Refactor `finalize_reducer` and `reduce` functions to remove redundant target checks. Simplified conditionals by retaining only the `TargetIsHopper` check, enhancing code clarity and maintainability.

* bug fix

* Add thread checks workaround for replicated cases

* Remove the is_one check

* fix lint error

* lint fix

* Update autotune tests to use smaller matrix sizes for improved performance and reliability

* [Refactor] Update FinalizeReducer to FinalizeReducerOp and adjust related methods

- Refactored FinalizeReducer class to FinalizeReducerOp, updating constructor and method signatures for consistency with the new TileOperator structure.
- Enhanced layout inference and cloning methods in FinalizeReducerOpNode.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main.
- Adjusted header inclusions for improved organization and clarity across multiple files.

* [Refactor] Update atomic operations in common.h and modify test_example_flash_attention.py

- Enhanced atomic operations (Add, Min, Max) in common.h to handle half and bfloat16 types more efficiently.
- Updated test_example_flash_attention.py to call test_example_gqa_bwd instead of tilelang.testing.main, improving test organization.

* [Refactor] Simplify CopyNode::LowerBulkCopy logic and update test execution

- Removed redundant checks for contiguous memory access in CopyNode::LowerBulkCopy, streamlining the logic for TMA copy operations.
- Updated test_tilelang_kernel_gemm.py to comment out the main testing function and call a specific test for i8i8i32 tensor operations instead, improving test focus.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
Co-authored-by: Freebase6912 <amid-gauze-racing@duck.com>
parent b38bd69e
......@@ -767,8 +767,6 @@ class TilelangExtensionBuild(build_ext):
if self.inplace:
extdir = os.path.abspath('./tilelang/lib/')
logger.info(f"{extdir=}")
# Prepare arguments for the CMake configuration step.
# -DCMAKE_LIBRARY_OUTPUT_DIRECTORY sets where built libraries go
# -DPYTHON_EXECUTABLE ensures that the correct Python is used
......
......@@ -129,6 +129,14 @@ TVM_DLL const Op &tma_load_im2col();
*/
TVM_DLL const Op &tma_store();
/*!
* \brief tvm intrinsics for barrier initialization fence
*
* ptx_fence_barrier_init()
*
*/
const Op &ptx_fence_barrier_init();
/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
*
......
/*!
* \file src/op/finalize_reducer.cc
*
* Define finalize_reducer operator.
*/
#include "finalize_reducer.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
namespace tvm {
namespace tl {
using namespace tir;
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
}
Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
auto buffer = T.buffer_remap[reducer];
auto opt_layout = T.layout_map.Get(reducer);
ICHECK(opt_layout);
ICHECK(opt_layout->as<Fragment>());
auto layout = opt_layout->as<Fragment>().value();
Array<PrimExpr> indices_0;
indices_0.reserve(layout->OutputDim());
for (int i = 0; i < layout->OutputDim(); ++i)
indices_0.push_back(Var("__finred_" + std::to_string(i)));
const int64_t *p_extent = as_const_int(layout->ReplicateExtent());
ICHECK(p_extent);
int extent = *p_extent, scale = 1;
ICHECK(extent == 1 || extent == *as_const_int(T.thread_bounds->extent))
<< "Illegal finalize_reducer: extent=" << extent
<< "; T.thread_bounds=" << T.thread_bounds;
if (extent == 1)
return Evaluate(0);
std::array op_names{"tl::SumOp", "tl::MaxOp", "tl::MinOp"};
auto op_str = op_names[(int)op];
// adapted from ReduceOp
int reducing_threads = extent;
std::stringstream ss;
auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ", " << all_threads << ">::run_hopper";
} else {
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ">::run";
}
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
BufferLoad(buffer, indices_0)};
if (reducing_threads >= 32) {
PrimExpr workspace =
T.AddWorkspace(*as_const_int(T.thread_bounds->extent), buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(buffer->dtype, builtin::call_extern(), thread_reduce_args);
Stmt body = BufferStore(buffer, call, indices_0);
// make the outer spatial loop
for (int i = layout->OutputDim() - 1; i >= 0; i--) {
body = For(indices_0[i].as<Var>().value(), 0, layout->OutputShape()[i],
ForKind::kParallel, body);
}
return body;
}
LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
LayoutMap layout_map;
layout_map.Set(reducer, T.layout_map.Get(reducer).value());
return layout_map;
}
TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node);
}
TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
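For orientation (illustrative, not emitted verbatim by this diff): with a 128-thread block, a SumOp reducer, and thread offset 0, the lowering above produces roughly one extern call per fragment element:

    // __finred_i are the parallel spatial loop variables over the output shape.
    // The workspace argument is appended because reducing_threads >= 32.
    red[__finred_0] = tl::AllReduce<tl::SumOp, 128, 1, 0>::run(red[__finred_0], workspace);
    // On Hopper targets the call becomes ::run_hopper and additionally passes
    // the total thread count as a fifth template argument.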
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.
/*!
* \file src/op/finalize_reducer.h
* \brief Define finalize_reducer operator.
*/
#ifndef TVM_TL_OP_FINALIZE_REDUCER_H_
#define TVM_TL_OP_FINALIZE_REDUCER_H_
#include "../transform/layout_reducer.h"
#include "./operator.h"
namespace tvm {
namespace tl {
using namespace tir;
class FinalizeReducerOpNode : public TileOperatorNode {
public:
tir::Buffer reducer;
ReducerOpType op;
static constexpr const char *_type_key = "tl.FinalizeReducerOp";
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
static const Op &Get();
TileOperator Clone() const;
};
class FinalizeReducerOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_FINALIZE_REDUCER_H_
\ No newline at end of file
......@@ -124,6 +124,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
auto reducer_info_map =
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
if (reducer_info_map) {
for (auto &&[buffer, info] : reducer_info_map.value())
p->reducer_info_map_.Set(buffer, info);
}
StmtExprVisitor::VisitStmt_(op);
}
......@@ -202,6 +208,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
Buffer source_buffer, read_source_buffer;
for (const auto &[buffer, indices] : indice_map_) {
if (T.layout_map.count(buffer)) {
// skip reducers with rep=ALL
if (auto info = reducer_info_map_.Get(buffer->data);
info && info.value()->rep == ReducerRepType::ALL)
continue;
auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer)) {
source_buffer = buffer;
......@@ -298,6 +309,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
PrimExpr loop_total_size = 1;
for (Stmt l = root_; l.as<For>().has_value();
l = l.as<For>().value()->body)
loop_total_size = loop_total_size * l.as<For>().value()->extent;
while (!analyzer_.CanProve(
floormod(loop_total_size,
T.thread_bounds->extent * vector_size) == 0) &&
vector_size > 1)
vector_size /= 2;
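// Illustrative example (not part of the diff): with loop_total_size = 256,
// a thread extent of 128 and an initial vector_size of 4, 256 % (128 * 4) != 0,
// so vector_size is halved once to 2, where 256 % (128 * 2) == 0 holds.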
// Check if coalesced_width is defined
if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) {
......@@ -343,11 +364,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
// TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs
if (!is_one(loop_layout_->ReplicateExtent()) ||
!is_one(fragment->ReplicateExtent()))
continue;
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
if (!ProveFragmentContains(loop_layout_, fragment, vars,
......
......@@ -10,7 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include "../layout/layout.h"
#include "operator.h"
#include "../transform/layout_reducer.h"
#include "./operator.h"
namespace tvm {
namespace tl {
......@@ -112,6 +113,8 @@ private:
Array<IterVar> loop_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
// Mapping from buffer to reducer info.
Map<Var, ReducerInfo> reducer_info_map_;
};
class ParallelOp : public TileOperator {
......
......@@ -13,6 +13,7 @@
#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"
......@@ -237,9 +238,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;
bool has_arch = T.target->attrs.count("arch") > 0;
auto thread_offset = T.thread_bounds->min;
if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
if (TargetIsHopper(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << thread_offset
......
......@@ -1134,10 +1134,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
} else if (op->op.same_as(tl::sync_grid())) {
this->need_cooperative_groups_ = true;
this->PrintIndent();
this->stream << "cooperative_groups::grid_group grid = "
"cooperative_groups::this_grid();\n";
this->PrintIndent();
this->stream << "grid.sync();\n";
this->stream << "cooperative_groups::this_grid().sync();\n";
} else if (op->op.same_as(tl::loop_break())) {
this->PrintIndent();
this->stream << "break;\n";
......
......@@ -4,6 +4,7 @@
#include <cuda_runtime.h>
#endif
#include <cuda/atomic>
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>
......@@ -42,7 +43,7 @@ using int4_t = int4;
do { \
cudaError_t __err = cudaGetLastError(); \
if (__err != cudaSuccess) { \
snprintf(error_buf, ERROR_BUF_SIZE, "kernel_name: %s - %s", \
snprintf(error_buf, ERROR_BUF_SIZE, kernel_name ": %s - %s", \
cudaGetErrorName(__err), cudaGetErrorString(__err)); \
return -1; \
} \
......@@ -118,47 +119,72 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
return smem_int;
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}
template <typename T> struct normalize_atomic_type {
using type = T;
};
// // AtomicAdd Functions for FP32
// TL_DEVICE void AtomicAdd(float *address, float val) {
// atomicAdd(reinterpret_cast<float *>(address), val);
// }
template <> struct normalize_atomic_type<half_t> {
using type = half;
};
// AtomicAdd Functions for FP16
template <> TL_DEVICE void AtomicAdd(half_t *address, half_t val) {
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
template <> struct normalize_atomic_type<bfloat16_t> {
using type = __nv_bfloat16;
};
#endif
// AtomicAdd Functions for FP16
template <> TL_DEVICE void AtomicAdd(half_t *address, half_t *val) {
atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
template <typename T1, typename T2> TL_DEVICE T1 cuda_cast(T2 val) {
return T1(val);
}
// AtomicAdd Functions for FP16
template <> TL_DEVICE void AtomicAdd(half_t *address, float val) {
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
template <> TL_DEVICE half cuda_cast<half, float>(float val) {
return __float2half(val);
}
// AtomicAdd Functions for BFLOAT16
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
static_cast<__nv_bfloat16>(*val));
template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
return __float2bfloat16(val);
}
#endif
// AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val));
template <typename T1, typename T2>
TL_DEVICE void AtomicMax(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMax(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_max(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
#endif
template <typename T1, typename T2>
TL_DEVICE void AtomicMin(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicMin(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_min(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
} else {
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.fetch_add(cuda_cast<NT1>(val), cuda::memory_order(memory_order));
}
}
// AtomicAdd Functions for FP16x2
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
......@@ -168,12 +194,6 @@ TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
// AtomicAdd Functions for BFLOAT16
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) {
atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
static_cast<__nv_bfloat16>(val));
}
// AtomicAdd Functions for BFLOAT16x2
TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
atomicAdd(
......@@ -195,6 +215,18 @@ TL_DEVICE void AtomicAddx4(float *address, float *val) {
}
#endif
template <typename T> TL_DEVICE T AtomicLoad(T *address, int memory_order) {
cuda::atomic_ref<T, cuda::thread_scope_device> aref(*address);
return aref.load(cuda::memory_order(memory_order));
}
template <typename T1, typename T2>
TL_DEVICE void AtomicStore(T1 *address, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
}
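// Illustrative usage (not part of this diff); these wrappers are presumably
// what TileLang-generated kernels call, with the int memory_order arguments
// mapping onto cuda::memory_order:
//   AtomicAdd(counter, 1.0f);                                    // relaxed by default
//   float v = AtomicLoad(counter, int(cuda::memory_order_acquire));
//   AtomicStore(counter, v, int(cuda::memory_order_release));
//   AtomicMax(half_ptr, 3.0f);  // half/bf16 dispatch via normalize_atomic_type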
// DP4A
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
......
......@@ -10,6 +10,15 @@
// Template declaration for device-side debug printing (variable only)
template <typename T> __device__ void debug_print_var(const char *msg, T var);
// Overload for pointer type (supports any cv-qualified T*)
template <typename T> __device__ void debug_print_var(const char *msg, T *var) {
printf(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer "
"value=%p\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for signed char type
template <>
__device__ void debug_print_var<signed char>(const char *msg, signed char var) {
......
#pragma once
#include "common.h"
#include "cuda_fp8.h"
#include "gemm_mma.h"
#include "intrin.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
......
......@@ -21,6 +21,7 @@
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
#include "layout_reducer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
......@@ -570,6 +571,12 @@ private:
}
Stmt VisitStmt_(const ForNode *op) final {
Map<Var, ReducerInfo> reducer_info;
if (op->annotations.count(attr::kReducerInfo))
reducer_info = op->annotations.Get(attr::kReducerInfo)
->as<Map<Var, ReducerInfo>>()
.value();
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto root = GetRef<For>(op);
......@@ -614,8 +621,17 @@ private:
}
}
});
// Workaround: if a reducer is present, do not vectorize the loop.
// A better solution would be to isolate the reduction axis from vectorization.
bool has_reducer = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (!has_reducer)
if (const auto *store = obj.as<BufferStoreNode>()) {
has_reducer = reducer_info.count(store->buffer->data) != 0;
}
});
if (has_non_local) {
if (has_non_local && !has_reducer) {
for_node = VectorizeLoop(for_node);
}
......
/*!
* \file layout_reducer.cc
*
* Compute layout for local.reducer buffers and lower them to local.fragment.
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../layout/layout.h"
#include "../op/elem.h"
#include "../op/finalize_reducer.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"
namespace tvm {
namespace tl {
using namespace tir;
using namespace tir::transform;
using arith::IRMutatorWithAnalyzer;
ReducerInfoNode::ReducerInfoNode(const String &op_str, const String &rep_str) {
if (op_str == "sum")
op = ReducerOpType::SUM;
else if (op_str == "max")
op = ReducerOpType::MAX;
else if (op_str == "min")
op = ReducerOpType::MIN;
else
ICHECK(false) << "Unrecognized reducer_info op: " << op_str;
if (rep_str == "all")
rep = ReducerRepType::ALL;
else if (rep_str == "none")
rep = ReducerRepType::NONE;
else
ICHECK(false) << "Unrecognized reducer_info rep: " << rep_str;
}
class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
public:
private:
Stmt VisitStmt_(const AttrStmtNode *op) final {
auto prev_thread_var = thread_var_;
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
auto result = IRMutatorWithAnalyzer::VisitStmt_(op);
thread_var_ = prev_thread_var;
return result;
}
Stmt VisitStmt_(const BlockNode *op) final {
// Record annotations
if (op->annotations.count(attr::kReducerInfo)) {
auto map = op->annotations.Get(attr::kReducerInfo)
->as<Map<Var, Map<String, String>>>();
ICHECK(map) << "reducer_replication map is not defined";
for (auto &&[var, rep] : map.value()) {
reducer_info_map_.Set(
var, ReducerInfo{rep.Get("op").value(), rep.Get("rep").value()});
}
}
for (auto &&buffer : op->alloc_buffers) {
var_to_buffer_.Set(buffer->data, buffer);
}
auto result = IRMutatorWithAnalyzer::VisitStmt_(op).as<Block>().value();
// After visiting the body, attach the collected layout_map entries to the block annotations
auto p_result = result.CopyOnWrite();
auto layout_map = p_result->annotations.Get(attr::kLayoutMap)
->as<Map<Var, Layout>>()
.value_or(Map<Var, Layout>());
for (auto &&[k, v] : new_layout_map_)
layout_map.Set(k, v);
if (layout_map.size())
p_result->annotations.Set(attr::kLayoutMap, layout_map);
new_layout_map_.clear();
return result;
}
Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop
bool should_annotate = false;
if (inside_reducer_range_.size() > 0 && !already_annotated_) {
should_annotate = true;
already_annotated_ = true;
}
auto opt_result = IRMutatorWithAnalyzer::VisitStmt_(op).as<For>();
ICHECK(opt_result);
auto result = opt_result.value();
if (should_annotate) {
// we are leaving the current loop nest. later ones may annotate again
already_annotated_ = false;
auto p_result = result.CopyOnWrite();
p_result->annotations.Set(attr::kReducerInfo, inside_reducer_range_);
// Iterate over local.reducer.* buffers, append to reducer_op_map_, set
// layout by adding layout_map annotations, and convert scope to
// local.fragment
for (auto &&[reducer_var, info] : inside_reducer_range_) {
// Analyze the thread index bound; this must run inside the WS section
ICHECK(thread_var_.defined());
ICHECK(analyzer_->const_int_bound.IsBound(thread_var_->var));
auto const_int_bound = analyzer_->const_int_bound(thread_var_);
auto dtype = thread_var_->var.dtype();
int thread_min = const_int_bound->min_value;
int thread_extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
auto opt_buffer = var_to_buffer_.Get(reducer_var);
ICHECK(opt_buffer);
auto buffer = opt_buffer.value();
Fragment f;
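// rep == ALL: every thread keeps a full private copy of the buffer, so the
// fragment is replicated across the whole thread extent and finalize_reducer
// later combines the copies.
// rep == NONE: elements are flattened and distributed round-robin across
// threads (flatten_idx % thread_extent), with no replication.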
if (info->rep == ReducerRepType::ALL) {
f = Fragment(buffer->shape, {}, ReplicationPlaceholder(),
thread_extent, std::nullopt);
} else if (info->rep == ReducerRepType::NONE) {
PrimExpr flatten_idx = InputPlaceholder(0);
for (int i = 1; i < buffer->shape.size(); ++i)
flatten_idx = flatten_idx * buffer->shape[i] + InputPlaceholder(i);
f = Fragment(buffer->shape, {},
indexmod(flatten_idx, thread_extent) + thread_min, 1,
std::nullopt);
}
new_layout_map_.Set(buffer->data, f);
}
}
return result;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
//! TODO: check that the store is valid according to info->op
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode *op_) final {
auto op_ref = IRMutatorWithAnalyzer::VisitExpr_(op_).as<Call>().value();
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(op->args.size() > 0);
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var.value(),
reducer_info_map_.Get(var.value()).value());
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
ICHECK(op->args.size() == 1);
auto var = GetVarFromAccessPtr(op->args[0]);
ICHECK(inside_reducer_range_.count(var) == 1)
<< "T.finalize_reducer must have a pairing T.fill ahead of it, "
"enclosing a reduction range.";
op->args.push_back((int)inside_reducer_range_.Get(var).value()->op);
inside_reducer_range_.erase(var);
}
return op_ref;
}
ReducerLayoutAnnotator(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
IterVar thread_var_;
Map<Var, ReducerInfo> reducer_info_map_;
Map<Var, ReducerInfo> inside_reducer_range_;
bool already_annotated_ = false;
Map<Var, Buffer> var_to_buffer_;
Map<Var, Layout> new_layout_map_;
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
ReducerLayoutAnnotator substituter(&analyzer);
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
return f;
}
};
tvm::transform::Pass LayoutReducer() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return ReducerLayoutAnnotator::Substitute(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
});
} // namespace tl
} // namespace tvm
/*!
* \file layout_reducer.h
*/
#ifndef TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_
#define TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_
#include <tvm/tir/op.h>
#include "../layout/layout.h"
namespace tvm {
namespace tl {
enum class ReducerOpType { SUM, MAX, MIN };
enum class ReducerRepType { ALL, NONE };
struct ReducerInfoNode : Object {
ReducerOpType op;
ReducerRepType rep;
ReducerInfoNode() = default;
ReducerInfoNode(const String &op_str, const String &rep_str);
static constexpr const char *_type_key = "tl.ReducerInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
};
struct ReducerInfo : ObjectRef {
public:
TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
data_ = make_object<ReducerInfoNode>(op_str, rep_str);
}
TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode);
};
namespace attr {
constexpr const char *kReducerInfo = "reducer_info";
}
} // namespace tl
} // namespace tvm
#endif
......@@ -962,6 +962,7 @@ private:
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits =
static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
// Disable reuse of small arrays; they will be lowered to registers in LLVM.
// This rule only applies to non-special memory scopes.
if (const_nbits > 0 && const_nbits <= 32) {
......
......@@ -244,7 +244,8 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
if (op->else_case) {
scope_.push_back(std::vector<StmtEntry>());
{
With<arith::ConstraintContext> constraint(&analyzer_, real_condition);
With<arith::ConstraintContext> constraint(
&analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
this->VisitStmt(op->else_case.value());
}
auto v = Summarize(std::move(scope_.back()), nullptr);
......
......@@ -649,6 +649,8 @@ public:
*/
bool hasSimtCopy() const { return has_simt_copy_; }
bool onlyHasWgMMA() const { return only_has_wgmma_; }
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
Role role = marker_.GetRole(op);
......
......@@ -257,13 +257,13 @@ def matmul(M, N, K, with_roller):
def test_autotune_get_configs():
get_configs(8192, 8192, 8192, with_roller=True)
get_configs(8192, 8192, 8192, with_roller=False)
get_configs(1024, 1024, 1024, with_roller=True)
get_configs(1024, 1024, 1024, with_roller=False)
def test_autotune_matmul():
matmul(8192, 8192, 8192, with_roller=True)
matmul(8192, 8192, 8192, with_roller=False)
matmul(1024, 1024, 1024, with_roller=True)
matmul(1024, 1024, 1024, with_roller=False)
if __name__ == "__main__":
......
......@@ -131,7 +131,7 @@ def run_autotune(M: int, N: int, K: int):
def test_autotune_matmul():
run_autotune(8192, 8192, 8192)
run_autotune(1024, 1024, 1024)
if __name__ == "__main__":
......
......@@ -69,6 +69,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.FrontendLegalize()(mod)
# Simplify the IR expressions
mod = tir.transform.Simplify()(mod)
# Set layouts for reducers
mod = tilelang.transform.LayoutReducer()(mod)
# Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations
......