Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -30,3 +30,4 @@ scipy
tabulate
tornado
wheel
z3-solver>=4.13.0
\ No newline at end of file
# Runtime requirements
apache-tvm-ffi~=0.1.0
apache-tvm-ffi>=0.1.3
torch-c-dlpack-ext
cloudpickle
ml-dtypes
numpy>=1.23.5
......@@ -8,3 +9,4 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
z3-solver>=4.13.0
\ No newline at end of file
......@@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt,
/*annotations=*/tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>{},
/*step=*/step);
};
return ForFrame(n);
}
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
const Map<String, tvm::ffi::Any> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
......@@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array<PrimExpr> &extents,
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
n->f_make_for_loop = [annotations](const Array<Var> &vars,
const Array<Range> &doms,
Stmt body) -> Stmt {
n->f_make_for_loop =
[annotations](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
Optional<PrimExpr> step =
i < steps.size() ? steps[i] : Optional<PrimExpr>(std::nullopt);
body = For(var, dom->min, dom->extent, ForKind::kParallel, body,
/*thread_binding=*/std::nullopt, /*annotations=*/annotations);
/*thread_binding=*/std::nullopt, /*annotations=*/annotations,
/*step=*/step);
}
return body;
};
......@@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
ICHECK(n == 1);
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages));
if (!order.empty())
......@@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
anno.Set("tl_pipeline_sync", sync);
if (!groups.empty())
anno.Set("tl_pipeline_group", groups);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno);
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
return body;
};
return ForFrame(n);
......@@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
grouped_domain.push_back(group_size);
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
PrimExpr rem = loop_var * wave_size + index;
......@@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
if (analyzer.CanProveGreaterEqual(waves, 2)) {
new_body = SeqStmt({out_if, body});
}
Stmt outer =
For(loop_var, 0, waves, ForKind::kSerial, new_body, std::nullopt, anno);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
}
......
......@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h>
#include "arith/pattern_match.h"
#include "tvm/node/functor.h"
#include "tvm/node/repr_printer.h"
#include "utils.h"
namespace tvm {
......@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_);
.def_ro("forward_index", &LayoutNode::forward_index_)
.def("_DebugOutput", &LayoutNode::DebugOutput);
}
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
......@@ -297,13 +300,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
}
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}
// Step 1. Prove the product of InputShape is equal to the product of shape
// Step 1. Prove the product relation holds under rescale:
// prod(InputShape) * rescale_num == prod(shape) * rescale_den
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
......@@ -317,8 +324,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
// potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
ICHECK(az->CanProveEqual(input_shape_product * rescale_num,
shape_product * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den;
// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
......@@ -339,13 +348,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}
flat_index = flat_index + new_vars[i] * stride;
}
// Convert new flat index (in units of new elements) to the old flat index
// (in units of old elements) using the rational rescale factor.
// old_flat = floor((flat_index * rescale_den) / rescale_num)
PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num);
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
PrimExpr remaining = old_flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
......@@ -373,7 +386,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
......@@ -390,8 +406,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
// Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod))
ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den
<< " input fragment layout is = " << DebugOutput();
// 2) Build flat index from new-shape indices
......@@ -414,9 +431,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}
// Convert to old flat index units using the rational rescale factor.
// old_flat = floor((flat * rescale_den) / rescale_num)
PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num);
// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
PrimExpr remain = old_flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
......@@ -529,6 +549,12 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n);
}
Fragment Fragment::FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent) {
return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent,
std::nullopt);
}
// which means the forward_thread is rep_var -> lambda i, rep: rep
bool FragmentNode::IsCompletedReplicated() const {
arith::Analyzer analyzer;
......@@ -536,6 +562,52 @@ bool FragmentNode::IsCompletedReplicated() const {
ReplicationPlaceholder());
}
arith::IterMapResult FragmentNode::DetectInjective() const {
// lei:To perform injective check, we need to reverse the layout
// and use surjective check, now we use bijective check for convenience
// can be relaxed in future
arith::Analyzer analyzer;
// Build a flat indices array: [forward_thread_, forward_index_[...]]
Array<PrimExpr> indices;
indices.push_back(forward_thread_);
for (const auto &e : forward_index_) {
indices.push_back(e);
}
// Mirror Layout::InverseWithLevel(): if any participating shape is
// symbolic, relax to NoCheck and rely on runtime guards elsewhere.
auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
Array<PrimExpr> symbolic_dims;
for (const auto &dim : shape) {
if (!as_const_int(dim)) {
symbolic_dims.push_back(dim);
}
}
return symbolic_dims;
};
Array<PrimExpr> symbolic_dims = collect_symbolic(InputShape());
Array<PrimExpr> output_shape = OutputShape();
symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
output_shape.end());
// Also consider replicate size for fragments
if (!as_const_int(ReplicateExtent())) {
symbolic_dims.push_back(ReplicateExtent());
}
symbolic_dims = collect_symbolic(symbolic_dims);
bool is_static_shape = symbolic_dims.empty();
auto level = is_static_shape ? arith::IterMapLevel::Bijective
: arith::IterMapLevel::NoCheck;
if (!is_static_shape) {
DLOG(WARNING)
<< "Fragment::DetectInjective on symbolic layout, falling back to "
<< "NoCheck; symbolic dims: " << symbolic_dims;
}
return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
}
PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
......@@ -653,8 +725,19 @@ void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_);
}
.def_ro("replicate_size", &FragmentNode::replicate_size_)
.def("_DebugOutput", &FragmentNode::DebugOutput);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FragmentNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const FragmentNode *>(obj.get());
p->stream << node->DebugOutput();
})
.set_dispatch<LayoutNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const LayoutNode *>(obj.get());
p->stream << node->DebugOutput();
});
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
......
......@@ -6,6 +6,7 @@
#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_
#include <exception>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
......@@ -18,6 +19,25 @@ namespace tl {
using namespace tir;
// Common layout-related exceptions
class LayoutConflictException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
explicit LayoutConflictException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
class LoopLayoutInjectiveException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
explicit LoopLayoutInjectiveException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
class Layout;
class Fragment;
......@@ -42,8 +62,18 @@ public:
virtual Layout Inverse() const;
// Reshape the layout to a new logical shape. When aliasing buffers of
// different dtypes, the element count may change while the underlying
// byte-size stays equal. Use rescale_num/rescale_den to represent the
// ratio between the old element size and the new element size in bytes.
// Specifically, define factor = rescale_num / rescale_den where:
// new_num_elems = old_num_elems * factor
// For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1.
// i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;
arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
......@@ -86,7 +116,9 @@ public:
Layout Inverse() const final;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
......@@ -116,6 +148,8 @@ public:
bool IsCompletedReplicated() const;
arith::IterMapResult DetectInjective() const;
static void RegisterReflection();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
......@@ -141,6 +175,20 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
/*!
* \brief Create a fully replicated fragment layout.
*
* A fully replicated fragment means all threads hold identical copies of the
* entire buffer. This is useful for index buffers or masks that need to be
* accessed uniformly across all threads.
*
* \param shape The shape of the buffer.
* \param thread_extent The number of threads.
* \return A Fragment where each thread has a complete copy of all elements.
*/
TVM_DLL static Fragment FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
......
......@@ -5,7 +5,7 @@
*/
#include "./atomic_add.h"
#include "./region.h"
#include "utils.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
......@@ -26,32 +26,27 @@ using namespace tir;
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
*
* Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third
* argument is provided, it is interpreted as an integer immediate and stored as
* the node's coalesced width.
* regions and their backing Buffers from the first two region-style expressions
* in `args` (BufferLoad/BufferRegion), and stores them along with their
* ranges. If a third argument is provided, it is interpreted as an integer
* immediate and stored as the node's coalesced width.
*
* @param args Call-style PrimExprs where:
* - args[0] is the source region call,
* - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes:
* - The constructor checks that args[0] and args[1] are CallNodes.
* - The constructor checks that args[0] and args[1] are region-compatible.
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
......@@ -272,22 +267,22 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
Array<PrimExpr> new_args;
// Optional bounds predicates for src and dst
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
Array<PrimExpr> new_args;
// Load source value and cast to dst dtype if needed
PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
src_value = Cast(dst->dtype, src_value);
if (src_predicate.defined())
src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));
PrimExpr dst_value = BufferLoad(dst, dst_indices);
if (dst_predicate.defined())
dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
// Build a pointer to destination element using tvm_access_ptr
PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(dst, dst_indices)});
new_args.push_back(dst_value);
new_args.push_back(dst_ptr);
new_args.push_back(src_value);
new_args.push_back(memory_order);
......@@ -544,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......@@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
\ No newline at end of file
} // namespace tvm
......@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL AtomicAdd(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -22,8 +22,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
......@@ -34,6 +32,9 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
......@@ -99,6 +100,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -344,5 +351,35 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// __ldg(BufferLoad | Buffer, idx?) -> value
// Treat as a pure call that returns the loaded value.
TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tl
} // namespace tvm
......@@ -28,6 +28,10 @@ static constexpr const char *kWarpSpecializationScope =
static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization";
static constexpr const char *kLocalVarInit = "tl.local_var_init";
// A PrimFunc-level attribute carrying a list of handle Vars
// that must NOT be marked with the restrict qualifier in codegen.
// Type: Array<tir::Var>
static constexpr const char *kNonRestrictParams = "tl.non_restrict_params";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
......@@ -51,14 +55,11 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
/*!
* \brief Whether to disable dynamic tail split
*
* kDisableDynamicTailSplit = "tl.disable_dynamic_tail_split"
*
*/
static constexpr const char *kDisableDynamicTailSplit =
"tl.disable_dynamic_tail_split";
static constexpr const char *kLayoutVisualizationEnable =
"tl.layout_visualization_enable";
static constexpr const char *kLayoutVisualizationFormats =
"tl.layout_visualization_formats";
static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags";
/*!
* \brief Whether to disable thread storage synchronization
......@@ -82,18 +83,6 @@ static constexpr const char *kDisableThreadStorageSync =
*/
static constexpr const char *kForceLetInline = "tl.force_let_inline";
/*!
* \brief The size of the vectorized dimension in buffer, designed by user
*
* For example, if the vectorized dimension is 128 bits and the dtype of buffer
* A[m, k] is float16, the size of the vectorized dimension (i.e. k) in buffer A
* should be divisible by 8 (8 = 128 / 16).
*
* kDynamicAlignment = "tl.dynamic_alignment"
*
*/
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
/*!
* \brief Get the type of the CUDA tensor map
*
......@@ -138,6 +127,10 @@ TVM_DLL const Op &ieee_frsqrt();
// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
TVM_DLL const Op &ieee_fdiv();
// random op
TVM_DLL const Op &rng_init();
TVM_DLL const Op &rng_rand();
/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
......@@ -582,6 +575,49 @@ TVM_DLL const Op &device_assert();
*/
TVM_DLL const Op &device_assert_with_msg();
/*!
* \brief tilelang intrinsic for warp reduction sum.
*/
TVM_DLL const Op &warp_reduce_sum();
/*!
* \brief tilelang intrinsic for warp reduction max.
*/
TVM_DLL const Op &warp_reduce_max();
/*!
* \brief tilelang intrinsic for warp reduction min.
*/
TVM_DLL const Op &warp_reduce_min();
/*!
* \brief tilelang intrinsic for warp reduction bitand.
*/
TVM_DLL const Op &warp_reduce_bitand();
/*!
* \brief tilelang intrinsic for warp reduction bitor.
*/
TVM_DLL const Op &warp_reduce_bitor();
/*!
* \brief tilelang intrinsic for CUDA read-only cache load (__ldg).
*
* This op allows users to explicitly request a non-coherent cached load
* from global memory on CUDA by emitting `__ldg(&ptr[idx])` for 32-bit
* element types on supported architectures. It provides a direct way to
* leverage the read-only data cache for performance-sensitive loads when
* the compiler cannot infer `const __restrict__` automatically.
*
* Usage from TVMScript:
* y[i] = T.__ldg(x[i])
*
* The op takes one argument preferred as a BufferLoad identifying the
* source element; alternatively, backends may support passing a Buffer and
* index expression.
*/
TVM_DLL const Op &__ldg();
} // namespace tl
} // namespace tvm
......
......@@ -16,7 +16,7 @@
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "region.h"
#include "utils.h"
#include "../target/cuda.h"
#include "../target/utils.h"
......@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
}
} else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) {
} else if (dtype.is_float8()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) {
switch (dtype.bits()) {
......@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
/*!
* \brief Construct a Copy operator node from call arguments and a buffer map.
*
* This constructor parses the first two entries of `args` as Call nodes
* describing source and destination Regions (via RegionOp), extracts their
* Buffers and Ranges, and stores them on the newly created CopyNode. It also
* This constructor parses the first two entries of `args` as regions
* (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores
* them on the newly created CopyNode. It also
* reads optional arguments:
* - args[2] (IntImm): coalesced width (stored only if > 0),
* - args[3] (Bool): disable TMA lowering flag,
* - args[4] (IntImm): eviction policy.
*
* Preconditions:
* - `args` must contain at least two Call-compatible PrimExpr entries
* describing regions; an ICHECK will fail if they are not CallNodes.
* - `args` must contain at least two region-compatible PrimExpr entries
* (BufferLoad/BufferRegion); ICHECK will fail otherwise.
*
* @param args Array of PrimExpr where:
* - args[0] is the source Region call,
* - args[1] is the destination Region call,
* - optional args[2..4] are coalesced width, disable_tma, and eviction
* policy.
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
Copy::Copy(Array<PrimExpr> args) {
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
......@@ -183,15 +179,95 @@ TileOperator CopyNode::Clone() const {
* copy operation.
*/
Array<IterVar> CopyNode::MakeIterVars() const {
// Choose the range set from the lowest-level memory scope between src and
// dst. Scope levels: global < shared/shared.dyn/shared.tmem < local.fragment
// (fragment)
auto scope_level = [](const Buffer &b) -> int {
String s = b.scope();
if (s == "local.fragment" || s == "local")
return 2;
if (s == "shared" || s == "shared.dyn" || s == "shared.tmem")
return 1;
// default to global level for unknown scopes
return 0;
};
int src_level = scope_level(src);
int dst_level = scope_level(dst);
bool base_is_src = (src_level >= dst_level);
const Array<Range> &base_ranges = base_is_src ? src_range : dst_range;
// Sanity check: when switching away from the original (src_range),
// ensure the chosen base ranges are not provably smaller than the original
// per dimension. This guards against generating undersized loop domains.
// Improved logic: use two pointers to traverse both base_ranges and
// src_range, skipping dimensions with extent == 1. The number of non-1
// extents must match.
arith::Analyzer analyzer;
size_t base_dim = 0, src_dim = 0;
while (base_dim < base_ranges.size() && src_dim < src_range.size()) {
// Skip base extents that are 1
while (base_dim < base_ranges.size() &&
is_one(base_ranges[base_dim]->extent)) {
++base_dim;
}
// Skip src extents that are 1
while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) {
++src_dim;
}
// Both indices now at non-1, or at end
if (base_dim < base_ranges.size() && src_dim < src_range.size()) {
PrimExpr base_ext = base_ranges[base_dim]->extent;
PrimExpr src_ext = src_range[src_dim]->extent;
// Only fail if base extent is provably smaller than src extent
if (analyzer.CanProve(base_ext < src_ext)) {
std::ostringstream oss;
oss << "Selected loop range is smaller than original src range at "
"matched non-1 dimension: "
<< "base(extent=" << base_ext
<< ", scope=" << (base_is_src ? src.scope() : dst.scope())
<< ", min=" << base_ranges[base_dim]->min
<< ", base_dim=" << base_dim << ") < src(extent=" << src_ext
<< ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim
<< ", scope=" << src.scope() << ") for src=" << src->name
<< ", dst=" << dst->name << "\n";
oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n";
oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n";
oss << "base_ranges[" << base_dim
<< "]: min=" << base_ranges[base_dim]->min
<< ", extent=" << base_ext << "\n";
oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min
<< ", extent=" << src_ext << "\n";
LOG(FATAL) << oss.str();
}
++base_dim;
++src_dim;
}
}
// Any remaining unmatched dimensions in either range must all have extent ==
// 1
while (base_dim < base_ranges.size()) {
ICHECK(is_one(base_ranges[base_dim]->extent))
<< "base_ranges has extra non-1 extent at dim " << base_dim;
++base_dim;
}
while (src_dim < src_range.size()) {
ICHECK(is_one(src_range[src_dim]->extent))
<< "src_range has extra non-1 extent at dim " << src_dim;
++src_dim;
}
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
for (size_t i = 0; i < base_ranges.size(); i++) {
if (is_one(base_ranges[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
{Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
......@@ -250,6 +326,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
......@@ -302,7 +379,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
......@@ -475,16 +551,38 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
// This must be a global/shared layout, so we can skip the parallel op
// layout inference (parallel layout inference only annotate the loop layout
// and the register layout).
bool is_load = copy_inst == CopyInst::kBulkLoad;
bool is_load =
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
Map<Buffer, Layout> result_map;
// Collect fragment buffers from indices and mark them as fully replicated
// For Bulk Load/Store, fragment buffers used as indices should be
// replicated across all threads
PrimExpr thread_extent = T.thread_bounds->extent;
for (const auto &range : src_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
for (const auto &range : dst_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
// check shared layout is non-swizzle
// skip layout inference if shared layout is already annotated
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
// create a new layout map for tma linear layout
Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
result_map.Set(shared_tensor, linear_layout);
}
return result_map;
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
......@@ -493,7 +591,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
arith::Analyzer analyzer;
par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
}
return par_op_->InferLayout(T, level);
auto layout_map = par_op_->InferLayout(T, level);
return layout_map;
}
/**
* @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
......@@ -851,21 +950,31 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
For vectorized_thread_loop;
auto par_op = ParallelOp(transformed_loop);
if (is_cpu_target) {
if (is_cpu_target || dst.scope() == "local" || src.scope() == "local") {
if (src.scope() == "local" && dst.scope() != "local") {
LOG(WARNING) << "Copy from local buffer `" << src->name << "` to "
<< dst.scope() << " buffer `" << dst->name
<< "` may cause conflicted write.";
}
vectorized_thread_loop = VectorizeLoop(transformed_loop);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
level);
}
auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
}
if (par_op->GetPredicate(T.thread_var).defined()) {
......@@ -1117,6 +1226,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
bool is_ld = false; // tcgen05.ld (tensor memory -> register)
bool is_st = false; // tcgen05.st (register -> tensor memory)
bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
bool src_needs_pack =
16 == src->dtype.bits(); // if needs .pack::16b when is_ld
bool dst_needs_unpack =
16 == dst->dtype.bits(); // if needs .unpack::16b when is_st
if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
is_ld = true;
} else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
......@@ -1124,9 +1238,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
} else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
is_cp = true;
} else {
ICHECK(0) << "Unsupported tensor memory copy: "
<< "src scope = " << src.scope()
<< ", dst scope = " << dst.scope();
ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
<< src.scope() << ", dst scope = " << dst.scope();
}
// Currently tcgen05.cp is not supported
// TODO (mzw) Support tcgen05.cp
......@@ -1246,8 +1359,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
: relative_wg_idx * (num_chunks_each_wg * meta.width);
have_succeeded = true;
Array<PrimExpr> args;
const char *bool_str = src_needs_pack ? "true" : "false";
args.push_back(StringImm(meta.intrinsics_name + "<" +
std::to_string(num_chunks_each_wg) + ">"));
std::to_string(num_chunks_each_wg) + ", " +
bool_str + ">"));
args.push_back(
BufferLoad(src, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated later
......@@ -1724,20 +1839,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* GPU intrinsics.
*
* @param args Array of PrimExpr TL-call arguments (see list above).
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args) {
ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2];
node->c_step = args[3];
node->kernel = args[4].as<IntImm>().value()->value;
node->stride = args[5].as<IntImm>().value()->value;
node->dilation = args[6].as<IntImm>().value()->value;
node->padding = args[7].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value;
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src_ = node->srcRegion_->buffer;
node->dst_ = node->dstRegion_->buffer;
node->nhw_step_ = args[2];
node->c_step_ = args[3];
node->kernel_ = args[4].as<IntImm>().value()->value;
node->stride_ = args[5].as<IntImm>().value()->value;
node->dilation_ = args[6].as<IntImm>().value()->value;
node->padding_ = args[7].as<IntImm>().value()->value;
node->eviction_policy_ = args[8].as<IntImm>().value()->value;
data_ = std::move(node);
}
......@@ -1788,24 +1904,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const {
Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src->shape.size() == 4);
ICHECK(dst->shape.size() == 2);
ICHECK(src->dtype == dst->dtype);
ICHECK(src_.scope() == "global" &&
(dst_.scope() == "shared.dyn" || dst_.scope() == "shared"));
ICHECK(src_->shape.size() == 4);
ICHECK(dst_->shape.size() == 2);
ICHECK(src_->dtype == dst_->dtype);
Layout shared_layout;
if (T.layout_map.count(dst)) {
shared_layout = T.layout_map[dst];
if (T.layout_map.count(dst_)) {
shared_layout = T.layout_map[dst_];
}
TMAIm2ColDesc desc;
desc.rank = src->shape.size();
desc.data_type = to_CUtensorMapDataType(src->dtype);
desc.global_addr = src->data;
desc.global_shape = ReverseArray(src->shape);
desc.rank = src_->shape.size();
desc.data_type = to_CUtensorMapDataType(src_->dtype);
desc.global_addr = src_->data;
desc.global_shape = ReverseArray(src_->shape);
if (!src->strides.empty()) {
desc.global_stride = ReverseArray(src->strides);
if (!src_->strides.empty()) {
desc.global_stride = ReverseArray(src_->strides);
} else {
// Create stride from shape
PrimExpr stride = 1;
......@@ -1819,13 +1935,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * src->dtype.bytes();
return cast(DataType::Int(64), e) * src_->dtype.bytes();
});
desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding};
desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value;
desc.elem_stride = {1, stride_, stride_, 1};
desc.lower_corner = {-padding_, -padding_};
desc.upper_corner = {-padding_, -padding_};
desc.smem_box_pixel = Downcast<IntImm>(dst_->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst_->shape[1])->value;
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
......@@ -1839,15 +1955,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous,
dst->dtype.bits()))) {
dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
......@@ -1866,43 +1982,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
<< "Currently can only support divisible channel case";
global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back(
dilation *
FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]),
kernel));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel,
desc.global_shape[0] * kernel));
dilation_ *
FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]),
kernel_));
image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel,
desc.global_shape[0] * kernel_));
PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride_) +
1;
PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1,
stride) +
FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride_) +
1;
global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_);
global_coords.push_back(
stride *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) -
padding);
stride_ *
FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) -
padding_);
global_coords.push_back(
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 2);
args.push_back(create_desc);
args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_;
auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr);
for (auto coord : global_coords)
args.push_back(coord);
for (auto offset : image_offset)
args.push_back(offset);
args.push_back(this->eviction_policy);
args.push_back(this->eviction_policy_);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
......@@ -1944,12 +2060,37 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
return args;
}
void CopyNode::CollectFragmentLayouts(const PrimExpr &expr,
                                      const Map<Var, PrimExpr> &let_var_to_expr,
                                      const LayoutMap &existing_layouts,
                                      PrimExpr thread_extent,
                                      Range thread_bounds,
                                      Map<Buffer, Layout> &result_map) const {
  // Walk the expression tree bottom-up. For every load from a
  // "local.fragment" buffer that does not yet have a layout (neither in the
  // pre-existing map nor in the results collected so far), register a fully
  // replicated layout bound to the given thread range. Variables introduced
  // by let bindings are chased recursively through their bound expressions.
  PostOrderVisit(expr, [&](const ObjectRef &obj) {
    if (const auto *load = obj.as<BufferLoadNode>()) {
      const Buffer &buf = load->buffer;
      const bool is_fragment = buf.scope() == "local.fragment";
      const bool already_known =
          existing_layouts.count(buf) || result_map.count(buf);
      if (is_fragment && !already_known) {
        auto replicated =
            Fragment::FullyReplicated(buf->shape, thread_extent);
        result_map.Set(buf, replicated->BindThreadRange(thread_bounds));
      }
    } else if (const auto *var_node = obj.as<VarNode>()) {
      auto bound_var = tvm::ffi::GetRef<Var>(var_node);
      if (let_var_to_expr.count(bound_var)) {
        // Follow the let binding so fragments referenced only indirectly
        // through let variables are collected as well.
        CollectFragmentLayouts(let_var_to_expr[bound_var], let_var_to_expr,
                               existing_layouts, thread_extent, thread_bounds,
                               result_map);
      }
    }
  });
}
// Register the Copy operation with TVM's TIR system.
// This makes the copy operation available for use in TVM programs.
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
//   eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_TILE_OP(Copy, copy)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
......@@ -1974,7 +2115,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
//   dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
    .set_num_inputs(9)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
......
......@@ -269,6 +269,28 @@ protected:
* @return Reference to the singleton TVM Op representing this operator.
*/
TileOperator Clone() const;
private:
  /*!
   * \brief Collect fragment buffers from an expression and create fully
   * replicated layouts for them.
   *
   * Recursively searches the expression for BufferLoad nodes with
   * "local.fragment" scope, following let bindings. For each fragment buffer
   * found, creates a fully replicated layout and adds it to result_map.
   *
   * \param expr Expression to search.
   * \param let_var_to_expr Map from let variables to their bound expressions.
   * \param existing_layouts Existing layout map to check for already-inferred
   *        layouts.
   * \param thread_extent Number of threads for replication.
   * \param thread_bounds Thread bounds for binding the layout.
   * \param result_map Output map to store collected fragment layouts.
   */
void CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent, Range thread_bounds,
Map<Buffer, Layout> &result_map) const;
};
class Copy : public TileOperator {
......@@ -280,7 +302,7 @@ public:
* \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping.
*/
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Copy(Array<PrimExpr> args);
/*!
* \brief Get the TVM Op handle corresponding to this Copy op.
......@@ -296,14 +318,16 @@ public:
*/
class Conv2DIm2ColOpNode : public TileOperatorNode {
public:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix)
int stride; // Stride for convolution
int padding; // Padding amount
int dilation; // Dilation factor
int kernel; // Kernel size
int eviction_policy; // Cache eviction policy
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
BufferRegion srcRegion_, dstRegion_;
Buffer src_,
dst_; // Source (input feature map) and destination (im2col matrix)
int stride_; // Stride for convolution
int padding_; // Padding amount
int dilation_; // Dilation factor
int kernel_; // Kernel size
int eviction_policy_; // Cache eviction policy
PrimExpr nhw_step_; // Step size in NHW dimensions
PrimExpr c_step_; // Step size in channel dimension
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode);
......@@ -311,13 +335,15 @@ public:
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<Conv2DIm2ColOpNode>()
.def_ro("src", &Conv2DIm2ColOpNode::src)
.def_ro("dst", &Conv2DIm2ColOpNode::dst)
.def_ro("stride", &Conv2DIm2ColOpNode::stride)
.def_ro("padding", &Conv2DIm2ColOpNode::padding)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
.def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_)
.def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_)
.def_ro("src", &Conv2DIm2ColOpNode::src_)
.def_ro("dst", &Conv2DIm2ColOpNode::dst_)
.def_ro("stride", &Conv2DIm2ColOpNode::stride_)
.def_ro("padding", &Conv2DIm2ColOpNode::padding_)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_);
}
/*!
......@@ -342,7 +368,7 @@ class Conv2DIm2ColOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -17,7 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -52,62 +52,18 @@ using namespace tir;
* value].
* - args[0]: destination access (BufferLoad or pointer expression).
* - args[1]: value to fill (scalar or vector).
* @param vmap Mapping from buffer variables to Buffer objects; used to resolve
* the destination when args[0] is not a BufferLoad.
*
* Notes:
* - The constructor enforces constraints (e.g., stride == 1 ramps, constant
* lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
* of bounds.
*/
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
Fill::Fill(Array<PrimExpr> args) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
CHECK(ramp->stride.as<IntImmNode>()->value == 1)
<< "Only stride 1 ramps are supported";
const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion";
node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
node->region.push_back(Range::FromMinExtent(index, 1));
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
}
BufferRegion region = NormalizeToBufferRegion(args[0]);
node->dst = region->buffer;
node->region = region->region;
if (args[1]->dtype != node->dst->dtype) {
node->value = Cast(node->dst->dtype, args[1]);
......@@ -202,12 +158,17 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
......@@ -215,17 +176,22 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else if (dst.scope() == "local") {
auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop = VectorizeLoop(init_loop);
auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
......@@ -253,7 +219,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
return {};
}
// Register tl.fill:
// - Takes 2 inputs: destination region and the fill value.
// - Marked as opaque since it has side effects (memory writes).
TIR_REGISTER_TL_TILE_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
......
......@@ -45,7 +45,7 @@ private:
class Fill : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL Fill(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,6 +12,7 @@
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -29,12 +30,14 @@ using namespace tir;
* @param args TL operator arguments: expects at least two elements where
* `args[0]` is an access pointer identifying the reducer variable
* and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
* @param vmap Mapping from variables to Buffers used to look up the reducer
* Buffer.
*/
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args) {
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
// Normalize any supported region expression
// (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the
// underlying Buffer as reducer.
auto region = NormalizeToBufferRegion(args[0]);
node->reducer = region->buffer;
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
}
......@@ -156,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const {
return TileOperator(node);
}
// Register tl.finalize_reducer:
// - Takes 1 input: the reducer region.
// - Marked as opaque since it has side effects (memory writes).
TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
......
......@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -12,8 +12,8 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -41,106 +41,21 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
Gemm::Gemm(Array<PrimExpr> args) {
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
......@@ -165,11 +80,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->mbar_ = std::nullopt;
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
} else {
node->mbar_ = std::nullopt;
}
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
......@@ -443,13 +361,7 @@ bool GemmNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
......@@ -462,13 +374,7 @@ bool GemmNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
......@@ -535,9 +441,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
// Build access pointers from regions locally
PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1);
PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1);
PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3);
PrimExpr Aptr =
MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Bptr =
MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Cptr =
MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);
std::stringstream ss;
std::string op_name;
......@@ -579,11 +488,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
Array<PrimExpr> new_args;
auto mbarPtr =
MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, cCoords_));
new_args.push_back(mbarPtr_);
new_args.push_back(mbarPtr);
new_args.push_back(clearAccum_);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
......@@ -908,7 +819,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
return results;
}
TIR_REGISTER_TL_OP(Gemm, gemm)
TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -97,7 +97,7 @@ public:
// only will be enabled under cdna mfma instructions
int kPack_ = 1;
int wgWait_ = 0;
PrimExpr mbarPtr_;
BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_;
......@@ -144,7 +144,7 @@ private:
// Managed reference wrapping GemmNode, the tile-level GEMM operator.
class Gemm : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
  // Builds the node from serialized TL call arguments; `vmap` resolves
  // access-pointer vars to their Buffers.
  TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
  // Overload taking region-style arguments directly (no var-to-Buffer map).
  TVM_DLL Gemm(Array<PrimExpr> args);
  // The registered tl.gemm operator singleton.
  static const Op &Get();
};
......
......@@ -12,100 +12,17 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
#include "utils.h"
namespace tvm {
namespace tl {
using namespace tir;
// Normalize a GEMM argument into a canonical BufferRegion.
//
// Accepted forms:
//  * BufferRegion          — returned unchanged.
//  * BufferLoad            — each index becomes a Range: a stride-1 Ramp
//                            index maps to [base, base+lanes); any other
//                            index maps to a unit-extent range.
//  * tl.region(...) call   — reconstructed through RegionOp.
//  * tvm_access_ptr(...)   — the pointer var is resolved through `vmap`
//                            and the buffer's full shape is taken.
// Any other form is a fatal error (LOG(FATAL)).
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
                                            const BufferMap &vmap) {
  // Case 1: Already a BufferRegion
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }
  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
  // extent=1)
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> ranges;
    for (const PrimExpr &index : load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Only dense (stride-1), fixed-lane vectors map cleanly to a Range.
        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
            << "Only stride-1 Ramp is supported in GEMM region conversion";
        ICHECK(ramp->lanes.as<IntImmNode>())
            << "Scalable vector lanes not supported in GEMM region conversion";
        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        // Scalar index: a single-element range at that position.
        ranges.push_back(Range::FromMinExtent(index, 1));
      }
    }
    return BufferRegion(load->buffer, ranges);
  }
  // Case 3: Call nodes
  if (const auto *call = arg.as<CallNode>()) {
    // tl.region(...) — reconstruct via RegionOp
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args, vmap);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
    // builtin.tvm_access_ptr(...) — map var to Buffer and take full region
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      Var var = Downcast<Var>(call->args[1]);
      Buffer buf = vmap.at(var);
      Array<Range> ranges;
      // An access pointer carries no per-dimension extents, so cover the
      // whole buffer shape.
      for (PrimExpr extent : buf->shape) {
        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
      }
      return BufferRegion(buf, ranges);
    }
  }
  LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
  throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) addressing the 2D tile described by a
// BufferRegion. The element offset is accumulated from the region mins of
// every dimension except the last two; the accessible extent is the product
// of the last two region extents. rw_mask: 1=read, 2=write, 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                        int rw_mask) {
  const Buffer &buffer = region->buffer;
  const int rank = static_cast<int>(buffer->shape.size());
  ICHECK(rank >= 2) << "GEMM expects buffers with at least 2 dims";
  const DataType idx_dtype = buffer->shape[0].dtype();
  // Row-major strides: stride[rank-1] = 1, stride[d] = prod(shape[d+1:]).
  std::vector<PrimExpr> stride(rank);
  PrimExpr running = make_const(idx_dtype, 1);
  for (int d = rank - 1; d >= 0; --d) {
    stride[d] = running;
    running = running * buffer->shape[d];
  }
  // Element offset of the tile start: fold mins of the leading (batch) dims.
  PrimExpr elem_offset = make_const(idx_dtype, 0);
  for (int d = 0; d + 2 < rank; ++d) {
    elem_offset = elem_offset + region->region[d]->min * stride[d];
  }
  // Accessible extent in elements: the 2D tile spanned by the last two dims.
  PrimExpr tile_extent =
      region->region[rank - 2]->extent * region->region[rank - 1]->extent;
  // Annotate the pointee dtype and emit the access-pointer intrinsic call.
  PrimExpr type_annot = tir::TypeAnnotation(buffer->dtype);
  Array<PrimExpr> call_args{type_annot, buffer->data, elem_offset, tile_extent,
                            IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), call_args);
}
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
......@@ -128,19 +45,17 @@ static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
GemmPy::GemmPy(Array<PrimExpr> args) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
......@@ -165,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->mbar_ = std::nullopt;
if (args.size() > 16) {
if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbarRegion_ =
NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = node->mbarRegion_->buffer;
}
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
......@@ -219,7 +135,7 @@ GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
return GemmInst::kMFMA;
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
TargetIsSm100(target) || TargetIsSM120(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
......@@ -266,13 +182,7 @@ bool GemmPyNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
......@@ -285,13 +195,7 @@ bool GemmPyNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
......@@ -402,7 +306,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
return results;
}
TIR_REGISTER_TL_OP(GemmPy, gemm_py)
TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......@@ -428,6 +332,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
result.push_back(Integer(meta.enable_ws));
result.push_back(Integer(meta.enable_2cta));
}
return result;
});
......
......@@ -29,8 +29,8 @@ public:
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
BufferRegion mbarRegion_;
tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
......@@ -59,7 +59,8 @@ public:
.def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
.def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
.def_ro("mbar", &GemmPyNode::mbar_)
.def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_)
......@@ -82,7 +83,7 @@ private:
// Managed reference wrapping GemmPyNode, the Python-facing GEMM operator.
class GemmPy : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
  // Builds the node from serialized TL call arguments; `vmap` resolves
  // access-pointer vars to their Buffers.
  TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
  // Overload taking region-style arguments directly (no var-to-Buffer map).
  TVM_DLL GemmPy(Array<PrimExpr> args);
  // The registered tl.gemm_py operator singleton.
  static const Op &Get();
};
......
......@@ -14,6 +14,7 @@
#include "../target/utils.h"
#include "builtin.h"
#include "gemm.h"
#include "utils.h"
namespace tvm {
namespace tl {
......@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
* The populated GemmSPNode is stored in the instance's internal data_ pointer.
*
* @param args Positional TL call arguments in the above order.
* @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
*
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
GemmSP::GemmSP(Array<PrimExpr> args) {
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->a_ = vmap[GetVarFromAccessPtr(args[0])];
node->e_ = vmap[GetVarFromAccessPtr(args[1])];
node->b_ = vmap[GetVarFromAccessPtr(args[2])];
node->c_ = vmap[GetVarFromAccessPtr(args[3])];
node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->a_ = node->aRegion_->buffer;
node->e_ = node->eRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value;
......@@ -298,12 +302,25 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
return results;
}
TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Static-init registration of GemmSP reflection metadata.
// NOTE(review): GemmSPNode::RegisterReflection() is invoked again in the
// static-init block below — confirm the duplicate registration is intended.
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
// Registers the script-printer display name for the warp-policy op.
TVM_REGISTER_OP("tl.GemmSPWarpPolicy")
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmSPWarpPolicy");
// One-time static registration: reflection metadata for sparse-GEMM nodes and
// the FFI entry point that lets the frontend drive warp partitioning.
TVM_FFI_STATIC_INIT_BLOCK() {
  GemmSPNode::RegisterReflection();
  GemmSPWarpPolicyNode::RegisterReflection();
  namespace refl = tvm::ffi::reflection;
  // Exposed as "tl.GemmSPWarpPolicyComputeWarpPartition"; mutates `policy`
  // in place, so the wrapper deliberately returns nothing.
  refl::GlobalDef().def(
      "tl.GemmSPWarpPolicyComputeWarpPartition",
      [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
         bool use_wgmma, int bits) {
        policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
        return;
      });
}
} // namespace tl
} // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment