Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
...@@ -30,3 +30,4 @@ scipy ...@@ -30,3 +30,4 @@ scipy
tabulate tabulate
tornado tornado
wheel wheel
z3-solver>=4.13.0
\ No newline at end of file
# Runtime requirements # Runtime requirements
apache-tvm-ffi~=0.1.0 apache-tvm-ffi>=0.1.3
torch-c-dlpack-ext
cloudpickle cloudpickle
ml-dtypes ml-dtypes
numpy>=1.23.5 numpy>=1.23.5
...@@ -8,3 +9,4 @@ torch ...@@ -8,3 +9,4 @@ torch
torch>=2.7; platform_system == 'Darwin' torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3 tqdm>=4.62.3
typing-extensions>=4.10.0 typing-extensions>=4.10.0
z3-solver>=4.13.0
\ No newline at end of file
...@@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { ...@@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
n->vars.push_back(var); n->vars.push_back(var);
n->doms.push_back(Range(0, dom)); n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms, n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt { const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1); ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1); ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body); Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt,
/*annotations=*/tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>{},
/*step=*/step);
}; };
return ForFrame(n); return ForFrame(n);
} }
ForFrame ParallelFor(const Array<PrimExpr> &extents, ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) { const Map<String, tvm::ffi::Any> &annotations) {
using namespace tvm::tir; using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>(); ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size()); n->vars.reserve(extents.size());
...@@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array<PrimExpr> &extents, ...@@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array<PrimExpr> &extents,
n->vars.push_back(Var("v", extent.dtype())); n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent)); n->doms.push_back(Range(make_const(dtype, 0), extent));
} }
n->f_make_for_loop = [annotations](const Array<Var> &vars, n->f_make_for_loop =
const Array<Range> &doms, [annotations](const Array<Var> &vars, const Array<Range> &doms,
Stmt body) -> Stmt { const Array<Optional<PrimExpr>> &steps, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
int n = vars.size(); int n = vars.size();
for (int i = n - 1; i >= 0; --i) { for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i]; Range dom = doms[i];
Var var = vars[i]; Var var = vars[i];
Optional<PrimExpr> step =
i < steps.size() ? steps[i] : Optional<PrimExpr>(std::nullopt);
body = For(var, dom->min, dom->extent, ForKind::kParallel, body, body = For(var, dom->min, dom->extent, ForKind::kParallel, body,
/*thread_binding=*/std::nullopt, /*annotations=*/annotations); /*thread_binding=*/std::nullopt, /*annotations=*/annotations,
/*step=*/step);
} }
return body; return body;
}; };
...@@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, ...@@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
n->vars.push_back(Var("v", dtype)); n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop)); n->doms.push_back(Range(std::move(start), stop));
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms, n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt { Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
int n = vars.size(); int n = vars.size();
ICHECK(n == 1); ICHECK(n == 1);
Map<String, ObjectRef> anno; Map<String, tvm::ffi::Any> anno;
if (num_stages > 0) if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages)); anno.Set("num_stages", PrimExpr(num_stages));
if (!order.empty()) if (!order.empty())
...@@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, ...@@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
anno.Set("tl_pipeline_sync", sync); anno.Set("tl_pipeline_sync", sync);
if (!groups.empty()) if (!groups.empty())
anno.Set("tl_pipeline_group", groups); anno.Set("tl_pipeline_group", groups);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body, body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno); /*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
return body; return body;
}; };
return ForFrame(n); return ForFrame(n);
...@@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size, ...@@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
grouped_domain.push_back(group_size); grouped_domain.push_back(group_size);
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms, n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt { const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size()); ICHECK_EQ(vars.size(), doms.size());
Map<String, ObjectRef> anno; Map<String, tvm::ffi::Any> anno;
Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr()); Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
PrimExpr rem = loop_var * wave_size + index; PrimExpr rem = loop_var * wave_size + index;
...@@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size, ...@@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
if (analyzer.CanProveGreaterEqual(waves, 2)) { if (analyzer.CanProveGreaterEqual(waves, 2)) {
new_body = SeqStmt({out_if, body}); new_body = SeqStmt({out_if, body});
} }
Stmt outer = Optional<PrimExpr> step =
For(loop_var, 0, waves, ForKind::kSerial, new_body, std::nullopt, anno); !steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
for (int i = 0; i < vars.size() - 1; ++i) { for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer); outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
} }
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include "arith/pattern_match.h" #include "arith/pattern_match.h"
#include "tvm/node/functor.h"
#include "tvm/node/repr_printer.h"
#include "utils.h" #include "utils.h"
namespace tvm { namespace tvm {
...@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() { ...@@ -78,7 +80,8 @@ void LayoutNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LayoutNode>() refl::ObjectDef<LayoutNode>()
.def_ro("input_size", &LayoutNode::input_size_) .def_ro("input_size", &LayoutNode::input_size_)
.def_ro("forward_index", &LayoutNode::forward_index_); .def_ro("forward_index", &LayoutNode::forward_index_)
.def("_DebugOutput", &LayoutNode::DebugOutput);
} }
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const { void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
...@@ -297,13 +300,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const { ...@@ -297,13 +300,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
} }
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape, Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: if shape is the same, return the original layout // Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) { if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this); return ffi::GetRef<Layout>(this);
} }
// Step 1. Prove the product of InputShape is equal to the product of shape // Step 1. Prove the product relation holds under rescale:
// prod(InputShape) * rescale_num == prod(shape) * rescale_den
PrimExpr input_shape_product = Integer(1); PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) { for (const auto &dim : InputShape()) {
input_shape_product *= dim; input_shape_product *= dim;
...@@ -317,8 +324,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape, ...@@ -317,8 +324,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
// potential null dereference paths flagged by static analysis. // potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer; arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer; arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product)) ICHECK(az->CanProveEqual(input_shape_product * rescale_num,
<< "InputShape() = " << InputShape() << " shape = " << shape; shape_product * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den;
// Step 2. Create new forward indices by reshaping // Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable // For each dimension in the new shape, we create a placeholder variable
...@@ -339,13 +348,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape, ...@@ -339,13 +348,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
} }
flat_index = flat_index + new_vars[i] * stride; flat_index = flat_index + new_vars[i] * stride;
} }
// Convert new flat index (in units of new elements) to the old flat index
// (in units of old elements) using the rational rescale factor.
// old_flat = floor((flat_index * rescale_den) / rescale_num)
PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num);
// Step 4. Convert flat index back to original shape indices // Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]: // For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm) // i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm) // i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ... // ...
Array<PrimExpr> original_indices; Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index; PrimExpr remaining = old_flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) { for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1); PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) { for (size_t j = i + 1; j < InputShape().size(); ++j) {
...@@ -373,7 +386,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape, ...@@ -373,7 +386,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
} }
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape, Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: identical input shape, return self // Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) { if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this); return ffi::GetRef<Fragment>(this);
...@@ -390,8 +406,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape, ...@@ -390,8 +406,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
// Use provided analyzer if present, otherwise a local fallback. // Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer; arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer; arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod)) ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape << "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den
<< " input fragment layout is = " << DebugOutput(); << " input fragment layout is = " << DebugOutput();
// 2) Build flat index from new-shape indices // 2) Build flat index from new-shape indices
...@@ -414,9 +431,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape, ...@@ -414,9 +431,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
stride = stride * shape[j]; stride = stride * shape[j];
flat = flat + new_vars[i] * stride; flat = flat + new_vars[i] * stride;
} }
// Convert to old flat index units using the rational rescale factor.
// old_flat = floor((flat * rescale_den) / rescale_num)
PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num);
// 3) Recover original indices from flat index // 3) Recover original indices from flat index
Array<PrimExpr> orig_indices; Array<PrimExpr> orig_indices;
PrimExpr remain = flat; PrimExpr remain = old_flat;
for (size_t i = 0; i < InputShape().size(); ++i) { for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1); PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) for (size_t j = i + 1; j < InputShape().size(); ++j)
...@@ -529,6 +549,12 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, ...@@ -529,6 +549,12 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n); data_ = std::move(n);
} }
Fragment Fragment::FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent) {
return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent,
std::nullopt);
}
// which means the forward_thread is rep_var -> lambda i, rep: rep // which means the forward_thread is rep_var -> lambda i, rep: rep
bool FragmentNode::IsCompletedReplicated() const { bool FragmentNode::IsCompletedReplicated() const {
arith::Analyzer analyzer; arith::Analyzer analyzer;
...@@ -536,6 +562,52 @@ bool FragmentNode::IsCompletedReplicated() const { ...@@ -536,6 +562,52 @@ bool FragmentNode::IsCompletedReplicated() const {
ReplicationPlaceholder()); ReplicationPlaceholder());
} }
arith::IterMapResult FragmentNode::DetectInjective() const {
// lei:To perform injective check, we need to reverse the layout
// and use surjective check, now we use bijective check for convenience
// can be relaxed in future
arith::Analyzer analyzer;
// Build a flat indices array: [forward_thread_, forward_index_[...]]
Array<PrimExpr> indices;
indices.push_back(forward_thread_);
for (const auto &e : forward_index_) {
indices.push_back(e);
}
// Mirror Layout::InverseWithLevel(): if any participating shape is
// symbolic, relax to NoCheck and rely on runtime guards elsewhere.
auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
Array<PrimExpr> symbolic_dims;
for (const auto &dim : shape) {
if (!as_const_int(dim)) {
symbolic_dims.push_back(dim);
}
}
return symbolic_dims;
};
Array<PrimExpr> symbolic_dims = collect_symbolic(InputShape());
Array<PrimExpr> output_shape = OutputShape();
symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
output_shape.end());
// Also consider replicate size for fragments
if (!as_const_int(ReplicateExtent())) {
symbolic_dims.push_back(ReplicateExtent());
}
symbolic_dims = collect_symbolic(symbolic_dims);
bool is_static_shape = symbolic_dims.empty();
auto level = is_static_shape ? arith::IterMapLevel::Bijective
: arith::IterMapLevel::NoCheck;
if (!is_static_shape) {
DLOG(WARNING)
<< "Fragment::DetectInjective on symbolic layout, falling back to "
<< "NoCheck; symbolic dims: " << symbolic_dims;
}
return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
}
PrimExpr FragmentNode::ThreadExtent() const { PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1); Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer; arith::Analyzer analyzer;
...@@ -653,8 +725,19 @@ void FragmentNode::RegisterReflection() { ...@@ -653,8 +725,19 @@ void FragmentNode::RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FragmentNode>() refl::ObjectDef<FragmentNode>()
.def_ro("forward_thread", &FragmentNode::forward_thread_) .def_ro("forward_thread", &FragmentNode::forward_thread_)
.def_ro("replicate_size", &FragmentNode::replicate_size_); .def_ro("replicate_size", &FragmentNode::replicate_size_)
} .def("_DebugOutput", &FragmentNode::DebugOutput);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FragmentNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const FragmentNode *>(obj.get());
p->stream << node->DebugOutput();
})
.set_dispatch<LayoutNode>([](const ObjectRef &obj, ReprPrinter *p) {
auto *node = static_cast<const LayoutNode *>(obj.get());
p->stream << node->DebugOutput();
});
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef TVM_TL_LAYOUT_LAYOUT_H_ #ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_ #define TVM_TL_LAYOUT_LAYOUT_H_
#include <exception>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h> #include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h> #include <tvm/ffi/object.h>
...@@ -18,6 +19,25 @@ namespace tl { ...@@ -18,6 +19,25 @@ namespace tl {
using namespace tir; using namespace tir;
// Common layout-related exceptions
class LayoutConflictException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
explicit LayoutConflictException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
class LoopLayoutInjectiveException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
explicit LoopLayoutInjectiveException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
class Layout; class Layout;
class Fragment; class Fragment;
...@@ -42,8 +62,18 @@ public: ...@@ -42,8 +62,18 @@ public:
virtual Layout Inverse() const; virtual Layout Inverse() const;
// Reshape the layout to a new logical shape. When aliasing buffers of
// different dtypes, the element count may change while the underlying
// byte-size stays equal. Use rescale_num/rescale_den to represent the
// ratio between the old element size and the new element size in bytes.
// Specifically, define factor = rescale_num / rescale_den where:
// new_num_elems = old_num_elems * factor
// For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1.
// i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
virtual Layout Reshape(const Array<PrimExpr> &shape, virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const; arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const; virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
...@@ -86,7 +116,9 @@ public: ...@@ -86,7 +116,9 @@ public:
Layout Inverse() const final; Layout Inverse() const final;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const; Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final; std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
...@@ -116,6 +148,8 @@ public: ...@@ -116,6 +148,8 @@ public:
bool IsCompletedReplicated() const; bool IsCompletedReplicated() const;
arith::IterMapResult DetectInjective() const;
static void RegisterReflection(); static void RegisterReflection();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
...@@ -141,6 +175,20 @@ public: ...@@ -141,6 +175,20 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size, PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var); Optional<Var> replicate_var);
/*!
* \brief Create a fully replicated fragment layout.
*
* A fully replicated fragment means all threads hold identical copies of the
* entire buffer. This is useful for index buffers or masks that need to be
* accessed uniformly across all threads.
*
* \param shape The shape of the buffer.
* \param thread_extent The number of threads.
* \return A Fragment where each thread has a complete copy of all elements.
*/
TVM_DLL static Fragment FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
}; };
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
*/ */
#include "./atomic_add.h" #include "./atomic_add.h"
#include "./region.h" #include "utils.h"
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
...@@ -26,32 +26,27 @@ using namespace tir; ...@@ -26,32 +26,27 @@ using namespace tir;
* @brief Construct an AtomicAdd operator from call arguments and a buffer map. * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
* *
* Builds the internal AtomicAddNode, extracts the source and destination * Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions * regions and their backing Buffers from the first two region-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third * in `args` (BufferLoad/BufferRegion), and stores them along with their
* argument is provided, it is interpreted as an integer immediate and stored as * ranges. If a third argument is provided, it is interpreted as an integer
* the node's coalesced width. * immediate and stored as the node's coalesced width.
* *
* @param args Call-style PrimExprs where: * @param args Call-style PrimExprs where:
* - args[0] is the source region call, * - args[0] is the source region call,
* - args[1] is the destination region call, * - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width. * - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes: * Notes:
* - The constructor checks that args[0] and args[1] are CallNodes. * - The constructor checks that args[0] and args[1] are region-compatible.
* - The constructed node is stored in this->data_. * - The constructed node is stored in this->data_.
*/ */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) { AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>(); ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto expr = args[i]; auto region = NormalizeToBufferRegion(args[i]);
auto call = expr.as<CallNode>(); rgs[i] = region->region;
ICHECK(call); bf[i] = region->buffer;
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
} }
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
...@@ -272,22 +267,22 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -272,22 +267,22 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0); Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1); Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
Array<PrimExpr> new_args;
// Optional bounds predicates for src and dst
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0); PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1); PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
Array<PrimExpr> new_args; // Load source value and cast to dst dtype if needed
PrimExpr src_value = BufferLoad(src, src_indices); PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype) if (src->dtype != dst->dtype)
src_value = Cast(dst->dtype, src_value); src_value = Cast(dst->dtype, src_value);
if (src_predicate.defined())
src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));
PrimExpr dst_value = BufferLoad(dst, dst_indices); // Build a pointer to destination element using tvm_access_ptr
if (dst_predicate.defined()) PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype)); {BufferLoad(dst, dst_indices)});
new_args.push_back(dst_value); new_args.push_back(dst_ptr);
new_args.push_back(src_value); new_args.push_back(src_value);
new_args.push_back(memory_order); new_args.push_back(memory_order);
...@@ -544,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -544,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop; return vectorized_thread_loop;
} }
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
...@@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) ...@@ -552,4 +547,4 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); } TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator { ...@@ -65,7 +65,7 @@ class AtomicAdd : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode); AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args, BufferMap vmap); TVM_DLL AtomicAdd(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -22,8 +22,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool); ...@@ -22,8 +22,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
...@@ -34,6 +32,9 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); ...@@ -34,6 +32,9 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
DataType cuTensorMapType() { return DataType::UInt(8, 128); } DataType cuTensorMapType() { return DataType::UInt(8, 128); }
...@@ -99,6 +100,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt) ...@@ -99,6 +100,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure)); "TCallEffectKind", Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
...@@ -344,5 +351,35 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive) ...@@ -344,5 +351,35 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// __ldg(BufferLoad | Buffer, idx?) -> value
// Treat as a pure call that returns the loaded value.
TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -28,6 +28,10 @@ static constexpr const char *kWarpSpecializationScope = ...@@ -28,6 +28,10 @@ static constexpr const char *kWarpSpecializationScope =
static constexpr const char *kCustomWarpSpecialization = static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization"; "kCustomWarpSpecialization";
static constexpr const char *kLocalVarInit = "tl.local_var_init"; static constexpr const char *kLocalVarInit = "tl.local_var_init";
// A PrimFunc-level attribute carrying a list of handle Vars
// that must NOT be marked with the restrict qualifier in codegen.
// Type: Array<tir::Var>
static constexpr const char *kNonRestrictParams = "tl.non_restrict_params";
} // namespace attr } // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations = static constexpr const char *kDebugMergeSharedMemoryAllocations =
...@@ -51,14 +55,11 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; ...@@ -51,14 +55,11 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kStorageRewriteDetectInplace = static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace"; "tl.storage_rewrite_detect_inplace";
/*! static constexpr const char *kLayoutVisualizationEnable =
* \brief Whether to disable dynamic tail split "tl.layout_visualization_enable";
* static constexpr const char *kLayoutVisualizationFormats =
* kDisableDynamicTailSplit = "tl.disable_dynamic_tail_split" "tl.layout_visualization_formats";
* static constexpr const char *kDeviceCompileFlags = "tl.device_compile_flags";
*/
static constexpr const char *kDisableDynamicTailSplit =
"tl.disable_dynamic_tail_split";
/*! /*!
* \brief Whether to disable thread storage synchronization * \brief Whether to disable thread storage synchronization
...@@ -82,18 +83,6 @@ static constexpr const char *kDisableThreadStorageSync = ...@@ -82,18 +83,6 @@ static constexpr const char *kDisableThreadStorageSync =
*/ */
static constexpr const char *kForceLetInline = "tl.force_let_inline"; static constexpr const char *kForceLetInline = "tl.force_let_inline";
/*!
* \brief The size of the vectorized dimension in buffer, designed by user
*
* For example, if the vectorized dimension is 128 bits and the dtype of buffer
* A[m, k] is float16, the size of the vectorized dimension (i.e. k) in buffer A
* should be divisible by 8 (8 = 128 / 16).
*
* kDynamicAlignment = "tl.dynamic_alignment"
*
*/
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
/*! /*!
* \brief Get the type of the CUDA tensor map * \brief Get the type of the CUDA tensor map
* *
...@@ -138,6 +127,10 @@ TVM_DLL const Op &ieee_frsqrt(); ...@@ -138,6 +127,10 @@ TVM_DLL const Op &ieee_frsqrt();
// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division // ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
TVM_DLL const Op &ieee_fdiv(); TVM_DLL const Op &ieee_fdiv();
// random op
TVM_DLL const Op &rng_init();
TVM_DLL const Op &rng_rand();
/*! /*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load * \brief tvm intrinsics for TMADescriptor creation for tiled load
* *
...@@ -582,6 +575,49 @@ TVM_DLL const Op &device_assert(); ...@@ -582,6 +575,49 @@ TVM_DLL const Op &device_assert();
*/ */
TVM_DLL const Op &device_assert_with_msg(); TVM_DLL const Op &device_assert_with_msg();
/*!
* \brief tilelang intrinsic for warp reduction sum.
*/
TVM_DLL const Op &warp_reduce_sum();
/*!
* \brief tilelang intrinsic for warp reduction max.
*/
TVM_DLL const Op &warp_reduce_max();
/*!
* \brief tilelang intrinsic for warp reduction min.
*/
TVM_DLL const Op &warp_reduce_min();
/*!
* \brief tilelang intrinsic for warp reduction bitand.
*/
TVM_DLL const Op &warp_reduce_bitand();
/*!
* \brief tilelang intrinsic for warp reduction bitor.
*/
TVM_DLL const Op &warp_reduce_bitor();
/*!
* \brief tilelang intrinsic for CUDA read-only cache load (__ldg).
*
* This op allows users to explicitly request a non-coherent cached load
* from global memory on CUDA by emitting `__ldg(&ptr[idx])` for 32-bit
* element types on supported architectures. It provides a direct way to
* leverage the read-only data cache for performance-sensitive loads when
* the compiler cannot infer `const __restrict__` automatically.
*
* Usage from TVMScript:
* y[i] = T.__ldg(x[i])
*
* The op takes one argument preferred as a BufferLoad identifying the
* source element; alternatively, backends may support passing a Buffer and
* index expression.
*/
TVM_DLL const Op &__ldg();
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "region.h" #include "utils.h"
#include "../target/cuda.h" #include "../target/cuda.h"
#include "../target/utils.h" #include "../target/utils.h"
...@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) { ...@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
} }
} else if (dtype.is_bfloat16()) { } else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) { } else if (dtype.is_float8()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) { } else if (dtype.is_int()) {
switch (dtype.bits()) { switch (dtype.bits()) {
...@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) { ...@@ -110,36 +110,32 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
/*! /*!
* \brief Construct a Copy operator node from call arguments and a buffer map. * \brief Construct a Copy operator node from call arguments and a buffer map.
* *
* This constructor parses the first two entries of `args` as Call nodes * This constructor parses the first two entries of `args` as regions
* describing source and destination Regions (via RegionOp), extracts their * (BufferLoad/BufferRegion), extracts their Buffers and Ranges, and stores
* Buffers and Ranges, and stores them on the newly created CopyNode. It also * them on the newly created CopyNode. It also
* reads optional arguments: * reads optional arguments:
* - args[2] (IntImm): coalesced width (stored only if > 0), * - args[2] (IntImm): coalesced width (stored only if > 0),
* - args[3] (Bool): disable TMA lowering flag, * - args[3] (Bool): disable TMA lowering flag,
* - args[4] (IntImm): eviction policy. * - args[4] (IntImm): eviction policy.
* *
* Preconditions: * Preconditions:
* - `args` must contain at least two Call-compatible PrimExpr entries * - `args` must contain at least two region-compatible PrimExpr entries
* describing regions; an ICHECK will fail if they are not CallNodes. * (BufferLoad/BufferRegion); ICHECK will fail otherwise.
* *
* @param args Array of PrimExpr where: * @param args Array of PrimExpr where:
* - args[0] is the source Region call, * - args[0] is the source Region call,
* - args[1] is the destination Region call, * - args[1] is the destination Region call,
* - optional args[2..4] are coalesced width, disable_tma, and eviction * - optional args[2..4] are coalesced width, disable_tma, and eviction
* policy. * policy.
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/ */
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) { Copy::Copy(Array<PrimExpr> args) {
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>(); ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2]; Array<Range> rgs[2];
Buffer bf[2]; Buffer bf[2];
for (int i = 0; i < 2; i++) { for (int i = 0; i < 2; i++) {
auto expr = args[i]; auto region = NormalizeToBufferRegion(args[i]);
auto call = expr.as<CallNode>(); rgs[i] = region->region;
ICHECK(call); bf[i] = region->buffer;
auto region = RegionOp(call->args, vmap);
rgs[i] = region->GetRanges();
bf[i] = region->GetBuffer();
} }
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
...@@ -183,15 +179,95 @@ TileOperator CopyNode::Clone() const { ...@@ -183,15 +179,95 @@ TileOperator CopyNode::Clone() const {
* copy operation. * copy operation.
*/ */
Array<IterVar> CopyNode::MakeIterVars() const { Array<IterVar> CopyNode::MakeIterVars() const {
// Choose the range set from the lowest-level memory scope between src and
// dst. Scope levels: global < shared/shared.dyn/shared.tmem < local.fragment
// (fragment)
auto scope_level = [](const Buffer &b) -> int {
String s = b.scope();
if (s == "local.fragment" || s == "local")
return 2;
if (s == "shared" || s == "shared.dyn" || s == "shared.tmem")
return 1;
// default to global level for unknown scopes
return 0;
};
int src_level = scope_level(src);
int dst_level = scope_level(dst);
bool base_is_src = (src_level >= dst_level);
const Array<Range> &base_ranges = base_is_src ? src_range : dst_range;
// Sanity check: when switching away from the original (src_range),
// ensure the chosen base ranges are not provably smaller than the original
// per dimension. This guards against generating undersized loop domains.
// Improved logic: use two pointers to traverse both base_ranges and
// src_range, skipping dimensions with extent == 1. The number of non-1
// extents must match.
arith::Analyzer analyzer;
size_t base_dim = 0, src_dim = 0;
while (base_dim < base_ranges.size() && src_dim < src_range.size()) {
// Skip base extents that are 1
while (base_dim < base_ranges.size() &&
is_one(base_ranges[base_dim]->extent)) {
++base_dim;
}
// Skip src extents that are 1
while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) {
++src_dim;
}
// Both indices now at non-1, or at end
if (base_dim < base_ranges.size() && src_dim < src_range.size()) {
PrimExpr base_ext = base_ranges[base_dim]->extent;
PrimExpr src_ext = src_range[src_dim]->extent;
// Only fail if base extent is provably smaller than src extent
if (analyzer.CanProve(base_ext < src_ext)) {
std::ostringstream oss;
oss << "Selected loop range is smaller than original src range at "
"matched non-1 dimension: "
<< "base(extent=" << base_ext
<< ", scope=" << (base_is_src ? src.scope() : dst.scope())
<< ", min=" << base_ranges[base_dim]->min
<< ", base_dim=" << base_dim << ") < src(extent=" << src_ext
<< ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim
<< ", scope=" << src.scope() << ") for src=" << src->name
<< ", dst=" << dst->name << "\n";
oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n";
oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n";
oss << "base_ranges[" << base_dim
<< "]: min=" << base_ranges[base_dim]->min
<< ", extent=" << base_ext << "\n";
oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min
<< ", extent=" << src_ext << "\n";
LOG(FATAL) << oss.str();
}
++base_dim;
++src_dim;
}
}
// Any remaining unmatched dimensions in either range must all have extent ==
// 1
while (base_dim < base_ranges.size()) {
ICHECK(is_one(base_ranges[base_dim]->extent))
<< "base_ranges has extra non-1 extent at dim " << base_dim;
++base_dim;
}
while (src_dim < src_range.size()) {
ICHECK(is_one(src_range[src_dim]->extent))
<< "src_range has extra non-1 extent at dim " << src_dim;
++src_dim;
}
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) { for (size_t i = 0; i < base_ranges.size(); i++) {
if (is_one(src_range[i]->extent)) if (is_one(base_ranges[i]->extent))
continue; continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype);
idx++; idx++;
loop_vars.push_back( loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); {Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar});
} }
return loop_vars; return loop_vars;
} }
...@@ -250,6 +326,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, ...@@ -250,6 +326,7 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs, const Array<IterVar> &ivs,
Array<PrimExpr> extents, int src_dst) const { Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range; Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list; Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0; size_t idx = 0;
...@@ -302,7 +379,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -302,7 +379,6 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (const auto &iv : loop_vars) for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom); analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size()) ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size() << "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name << ", src_range.size() = " << src_range.size() << ", src = " << src->name
...@@ -475,16 +551,38 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, ...@@ -475,16 +551,38 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
// This must be a global/shared layout, so we can skip the parallel op // This must be a global/shared layout, so we can skip the parallel op
// layout inference (parallel layout inference only annotate the loop layout // layout inference (parallel layout inference only annotate the loop layout
// and the register layout). // and the register layout).
bool is_load = copy_inst == CopyInst::kBulkLoad; bool is_load =
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst; Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src; Buffer shared_tensor = is_load ? dst : src;
Map<Buffer, Layout> result_map;
// Collect fragment buffers from indices and mark them as fully replicated
// For Bulk Load/Store, fragment buffers used as indices should be
// replicated across all threads
PrimExpr thread_extent = T.thread_bounds->extent;
for (const auto &range : src_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
for (const auto &range : dst_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
// check shared layout is non-swizzle // check shared layout is non-swizzle
// skip layout inference if shared layout is already annotated // skip layout inference if shared layout is already annotated
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) { if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
// create a new layout map for tma linear layout // create a new layout map for tma linear layout
Layout linear_layout = ComputeLinearLayout(shared_tensor); Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}}); result_map.Set(shared_tensor, linear_layout);
} }
return result_map;
} }
// for LDSM/STSM, the layout was deduced from register layout // for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy // so we can directly apply the layout of normal copy
...@@ -493,7 +591,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, ...@@ -493,7 +591,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
arith::Analyzer analyzer; arith::Analyzer analyzer;
par_op_ = ParallelOp((MakeSIMTLoop(&analyzer))); par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
} }
return par_op_->InferLayout(T, level); auto layout_map = par_op_->InferLayout(T, level);
return layout_map;
} }
/** /**
* @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA) * @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
...@@ -851,21 +950,31 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, ...@@ -851,21 +950,31 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
For vectorized_thread_loop; For vectorized_thread_loop;
auto par_op = ParallelOp(transformed_loop); auto par_op = ParallelOp(transformed_loop);
if (is_cpu_target) { if (is_cpu_target || dst.scope() == "local" || src.scope() == "local") {
if (src.scope() == "local" && dst.scope() != "local") {
LOG(WARNING) << "Copy from local buffer `" << src->name << "` to "
<< dst.scope() << " buffer `" << dst->name
<< "` may cause conflicted write.";
}
vectorized_thread_loop = VectorizeLoop(transformed_loop); vectorized_thread_loop = VectorizeLoop(transformed_loop);
} else { } else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict, std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree}; InferLevel::kFree};
for (auto level : levels) { for (auto level : levels) {
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer, par_op->InferLayout({T.target,
false, T.buffer_remap}, T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
level); level);
} }
auto loop_layout = par_op->GetLoopLayout(); auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var; auto thread_var = T.thread_var;
auto thread_loop = auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop); vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
} }
if (par_op->GetPredicate(T.thread_var).defined()) { if (par_op->GetPredicate(T.thread_var).defined()) {
...@@ -1117,6 +1226,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, ...@@ -1117,6 +1226,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
bool is_ld = false; // tcgen05.ld (tensor memory -> register) bool is_ld = false; // tcgen05.ld (tensor memory -> register)
bool is_st = false; // tcgen05.st (register -> tensor memory) bool is_st = false; // tcgen05.st (register -> tensor memory)
bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory)
bool src_needs_pack =
16 == src->dtype.bits(); // if needs .pack::16b when is_ld
bool dst_needs_unpack =
16 == dst->dtype.bits(); // if needs .unpack::16b when is_st
if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") {
is_ld = true; is_ld = true;
} else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") {
...@@ -1124,9 +1238,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, ...@@ -1124,9 +1238,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
} else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") {
is_cp = true; is_cp = true;
} else { } else {
ICHECK(0) << "Unsupported tensor memory copy: " ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = "
<< "src scope = " << src.scope() << src.scope() << ", dst scope = " << dst.scope();
<< ", dst scope = " << dst.scope();
} }
// Currently tcgen05.cp is not supported // Currently tcgen05.cp is not supported
// TODO (mzw) Support tcgen05.cp // TODO (mzw) Support tcgen05.cp
...@@ -1246,8 +1359,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, ...@@ -1246,8 +1359,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T,
: relative_wg_idx * (num_chunks_each_wg * meta.width); : relative_wg_idx * (num_chunks_each_wg * meta.width);
have_succeeded = true; have_succeeded = true;
Array<PrimExpr> args; Array<PrimExpr> args;
const char *bool_str = src_needs_pack ? "true" : "false";
args.push_back(StringImm(meta.intrinsics_name + "<" + args.push_back(StringImm(meta.intrinsics_name + "<" +
std::to_string(num_chunks_each_wg) + ">")); std::to_string(num_chunks_each_wg) + ", " +
bool_str + ">"));
args.push_back( args.push_back(
BufferLoad(src, {(int)logical_row_min, BufferLoad(src, {(int)logical_row_min,
(int)logical_col_min})); // Will be translated later (int)logical_col_min})); // Will be translated later
...@@ -1724,20 +1839,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const { ...@@ -1724,20 +1839,21 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* GPU intrinsics. * GPU intrinsics.
* *
* @param args Array of PrimExpr TL-call arguments (see list above). * @param args Array of PrimExpr TL-call arguments (see list above).
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/ */
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) { Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args) {
ObjectPtr<Conv2DIm2ColOpNode> node = ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>(); tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->nhw_step = args[2]; node->src_ = node->srcRegion_->buffer;
node->c_step = args[3]; node->dst_ = node->dstRegion_->buffer;
node->kernel = args[4].as<IntImm>().value()->value; node->nhw_step_ = args[2];
node->stride = args[5].as<IntImm>().value()->value; node->c_step_ = args[3];
node->dilation = args[6].as<IntImm>().value()->value; node->kernel_ = args[4].as<IntImm>().value()->value;
node->padding = args[7].as<IntImm>().value()->value; node->stride_ = args[5].as<IntImm>().value()->value;
node->eviction_policy = args[8].as<IntImm>().value()->value; node->dilation_ = args[6].as<IntImm>().value()->value;
node->padding_ = args[7].as<IntImm>().value()->value;
node->eviction_policy_ = args[8].as<IntImm>().value()->value;
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -1788,24 +1904,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const { ...@@ -1788,24 +1904,24 @@ TileOperator Conv2DIm2ColOpNode::Clone() const {
Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
ICHECK(TargetIsHopper(T.target)); ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && ICHECK(src_.scope() == "global" &&
(dst.scope() == "shared.dyn" || dst.scope() == "shared")); (dst_.scope() == "shared.dyn" || dst_.scope() == "shared"));
ICHECK(src->shape.size() == 4); ICHECK(src_->shape.size() == 4);
ICHECK(dst->shape.size() == 2); ICHECK(dst_->shape.size() == 2);
ICHECK(src->dtype == dst->dtype); ICHECK(src_->dtype == dst_->dtype);
Layout shared_layout; Layout shared_layout;
if (T.layout_map.count(dst)) { if (T.layout_map.count(dst_)) {
shared_layout = T.layout_map[dst]; shared_layout = T.layout_map[dst_];
} }
TMAIm2ColDesc desc; TMAIm2ColDesc desc;
desc.rank = src->shape.size(); desc.rank = src_->shape.size();
desc.data_type = to_CUtensorMapDataType(src->dtype); desc.data_type = to_CUtensorMapDataType(src_->dtype);
desc.global_addr = src->data; desc.global_addr = src_->data;
desc.global_shape = ReverseArray(src->shape); desc.global_shape = ReverseArray(src_->shape);
if (!src->strides.empty()) { if (!src_->strides.empty()) {
desc.global_stride = ReverseArray(src->strides); desc.global_stride = ReverseArray(src_->strides);
} else { } else {
// Create stride from shape // Create stride from shape
PrimExpr stride = 1; PrimExpr stride = 1;
...@@ -1819,13 +1935,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1819,13 +1935,13 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride; ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes // Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
return cast(DataType::Int(64), e) * src->dtype.bytes(); return cast(DataType::Int(64), e) * src_->dtype.bytes();
}); });
desc.elem_stride = {1, stride, stride, 1}; desc.elem_stride = {1, stride_, stride_, 1};
desc.lower_corner = {-padding, -padding}; desc.lower_corner = {-padding_, -padding_};
desc.upper_corner = {-padding, -padding}; desc.upper_corner = {-padding_, -padding_};
desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value; desc.smem_box_pixel = Downcast<IntImm>(dst_->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value; desc.smem_box_channel = Downcast<IntImm>(dst_->shape[1])->value;
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B); desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE); desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
...@@ -1839,15 +1955,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1839,15 +1955,15 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
if (StructuralEqual()(shared_layout, if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous, makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
} else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout( } else if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(
*stride, *continuous, *stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout( } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
*stride, *continuous, *stride, *continuous,
dst->dtype.bits()))) { dst_->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B); desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else { } else {
ICHECK(0) << "Cannot detect TMA layout."; ICHECK(0) << "Cannot detect TMA layout.";
...@@ -1866,43 +1982,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, ...@@ -1866,43 +1982,43 @@ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T,
<< "Currently can only support divisible channel case"; << "Currently can only support divisible channel case";
global_coords.push_back( global_coords.push_back(
FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0])); FloorMod(c_step_ * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back( image_offset.push_back(
dilation * dilation_ *
FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), FloorMod(FloorDiv(c_step_ * desc.smem_box_channel, desc.global_shape[0]),
kernel)); kernel_));
image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel, image_offset.push_back(dilation_ * FloorDiv(c_step_ * desc.smem_box_channel,
desc.global_shape[0] * kernel)); desc.global_shape[0] * kernel_));
PrimExpr h_dim = PrimExpr h_dim =
FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, FloorDiv(src_->shape[1] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride) + stride_) +
1; 1;
PrimExpr w_dim = PrimExpr w_dim =
FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, FloorDiv(src_->shape[2] + 2 * padding_ - (kernel_ - 1) * dilation_ - 1,
stride) + stride_) +
1; 1;
global_coords.push_back( global_coords.push_back(
stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding); stride_ * FloorMod(nhw_step_ * desc.smem_box_pixel, w_dim) - padding_);
global_coords.push_back( global_coords.push_back(
stride * stride_ *
FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - FloorMod(FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim), h_dim) -
padding); padding_);
global_coords.push_back( global_coords.push_back(
FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim)); FloorDiv(nhw_step_ * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args; Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 2); args.reserve(desc.rank * 2 + 2);
args.push_back(create_desc); args.push_back(create_desc);
args.push_back(0); // mbar placeholder args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst; auto dst_buffer = T.buffer_remap.count(dst_) ? T.buffer_remap[dst_] : dst_;
auto shared_addr = dst_buffer.access_ptr(2); auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr); args.push_back(shared_addr);
for (auto coord : global_coords) for (auto coord : global_coords)
args.push_back(coord); args.push_back(coord);
for (auto offset : image_offset) for (auto offset : image_offset)
args.push_back(offset); args.push_back(offset);
args.push_back(this->eviction_policy); args.push_back(this->eviction_policy_);
Stmt tma_copy = Stmt tma_copy =
IfThenElse(EQ(T.thread_var, T.thread_bounds->min), IfThenElse(EQ(T.thread_var, T.thread_bounds->min),
Evaluate(Call(DataType::Handle(), tma_load_im2col(), args))); Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
...@@ -1944,12 +2060,37 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const { ...@@ -1944,12 +2060,37 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
return args; return args;
} }
void CopyNode::CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent,
Range thread_bounds,
Map<Buffer, Layout> &result_map) const {
PostOrderVisit(expr, [&](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (bl->buffer.scope() == "local.fragment" &&
!existing_layouts.count(bl->buffer) &&
!result_map.count(bl->buffer)) {
auto f = Fragment::FullyReplicated(bl->buffer->shape, thread_extent);
result_map.Set(bl->buffer, f->BindThreadRange(thread_bounds));
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
if (let_var_to_expr.count(var)) {
CollectFragmentLayouts(let_var_to_expr[var], let_var_to_expr,
existing_layouts, thread_extent, thread_bounds,
result_map);
}
}
});
}
// Register the Copy operation with TVM's TIR system // Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs // This makes the copy operation available for use in TVM programs
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, // - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy // eviction_policy
// - Marked as opaque since it has side effects (memory writes) // - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy) TIR_REGISTER_TL_TILE_OP(Copy, copy)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
...@@ -1974,7 +2115,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -1974,7 +2115,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
// dilation, padding, eviction_policy // dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes) // - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(9) .set_num_inputs(9)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -269,6 +269,28 @@ protected: ...@@ -269,6 +269,28 @@ protected:
* @return Reference to the singleton TVM Op representing this operator. * @return Reference to the singleton TVM Op representing this operator.
*/ */
TileOperator Clone() const; TileOperator Clone() const;
private:
/*!
* \brief Collect fragment buffers from expression and create fully replicated
* layouts.
*
* Recursively searches the expression for BufferLoad nodes with
* "local.fragment" scope, following let bindings. For each found fragment
* buffer, creates a fully replicated layout and adds it to result_map.
*
* \param expr Expression to search.
* \param let_var_to_expr Map from let variables to their bound expressions.
* \param existing_layouts Existing layout map to check for already-inferred
* layouts. \param thread_extent Number of threads for replication. \param
* thread_bounds Thread bounds for binding the layout. \param result_map
* Output map to store collected fragment layouts.
*/
void CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent, Range thread_bounds,
Map<Buffer, Layout> &result_map) const;
}; };
class Copy : public TileOperator { class Copy : public TileOperator {
...@@ -280,7 +302,7 @@ public: ...@@ -280,7 +302,7 @@ public:
* \param args Expression arguments for the copy. * \param args Expression arguments for the copy.
* \param vmap Buffer variable mapping. * \param vmap Buffer variable mapping.
*/ */
TVM_DLL Copy(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Copy(Array<PrimExpr> args);
/*! /*!
* \brief Get the TVM Op handle corresponding to this Copy op. * \brief Get the TVM Op handle corresponding to this Copy op.
...@@ -296,14 +318,16 @@ public: ...@@ -296,14 +318,16 @@ public:
*/ */
class Conv2DIm2ColOpNode : public TileOperatorNode { class Conv2DIm2ColOpNode : public TileOperatorNode {
public: public:
Buffer src, dst; // Source (input feature map) and destination (im2col matrix) BufferRegion srcRegion_, dstRegion_;
int stride; // Stride for convolution Buffer src_,
int padding; // Padding amount dst_; // Source (input feature map) and destination (im2col matrix)
int dilation; // Dilation factor int stride_; // Stride for convolution
int kernel; // Kernel size int padding_; // Padding amount
int eviction_policy; // Cache eviction policy int dilation_; // Dilation factor
PrimExpr nhw_step; // Step size in NHW dimensions int kernel_; // Kernel size
PrimExpr c_step; // Step size in channel dimension int eviction_policy_; // Cache eviction policy
PrimExpr nhw_step_; // Step size in NHW dimensions
PrimExpr c_step_; // Step size in channel dimension
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode); TileOperatorNode);
...@@ -311,13 +335,15 @@ public: ...@@ -311,13 +335,15 @@ public:
static void RegisterReflection() { static void RegisterReflection() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<Conv2DIm2ColOpNode>() refl::ObjectDef<Conv2DIm2ColOpNode>()
.def_ro("src", &Conv2DIm2ColOpNode::src) .def_ro("srcRegion", &Conv2DIm2ColOpNode::srcRegion_)
.def_ro("dst", &Conv2DIm2ColOpNode::dst) .def_ro("dstRegion", &Conv2DIm2ColOpNode::dstRegion_)
.def_ro("stride", &Conv2DIm2ColOpNode::stride) .def_ro("src", &Conv2DIm2ColOpNode::src_)
.def_ro("padding", &Conv2DIm2ColOpNode::padding) .def_ro("dst", &Conv2DIm2ColOpNode::dst_)
.def_ro("dilation", &Conv2DIm2ColOpNode::dilation) .def_ro("stride", &Conv2DIm2ColOpNode::stride_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel) .def_ro("padding", &Conv2DIm2ColOpNode::padding_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy); .def_ro("dilation", &Conv2DIm2ColOpNode::dilation_)
.def_ro("kernel", &Conv2DIm2ColOpNode::kernel_)
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy_);
} }
/*! /*!
...@@ -342,7 +368,7 @@ class Conv2DIm2ColOp : public TileOperator { ...@@ -342,7 +368,7 @@ class Conv2DIm2ColOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode); Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h" #include "../transform/loop_vectorize.h"
#include "builtin.h" #include "builtin.h"
#include "region.h" #include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -52,62 +52,18 @@ using namespace tir; ...@@ -52,62 +52,18 @@ using namespace tir;
* value]. * value].
* - args[0]: destination access (BufferLoad or pointer expression). * - args[0]: destination access (BufferLoad or pointer expression).
* - args[1]: value to fill (scalar or vector). * - args[1]: value to fill (scalar or vector).
* @param vmap Mapping from buffer variables to Buffer objects; used to resolve
* the destination when args[0] is not a BufferLoad.
* *
* Notes: * Notes:
* - The constructor enforces constraints (e.g., stride == 1 ramps, constant * - The constructor enforces constraints (e.g., stride == 1 ramps, constant
* lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
* of bounds. * of bounds.
*/ */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) { Fill::Fill(Array<PrimExpr> args) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>(); ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
// Case 1: Region descriptor call (tl.region) BufferRegion region = NormalizeToBufferRegion(args[0]);
if (const auto *call = args[0].as<CallNode>()) { node->dst = region->buffer;
if (call->op.same_as(RegionOp::Get())) { node->region = region->region;
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
CHECK(ramp->stride.as<IntImmNode>()->value == 1)
<< "Only stride 1 ramps are supported";
const auto *lanes = ramp->lanes.as<IntImmNode>();
CHECK(lanes)
<< "Scalable vectors not supported in BufferRegion conversion";
node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
node->region.push_back(Range::FromMinExtent(index, 1));
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
}
if (args[1]->dtype != node->dst->dtype) { if (args[1]->dtype != node->dst->dtype) {
node->value = Cast(node->dst->dtype, args[1]); node->value = Cast(node->dst->dtype, args[1]);
...@@ -202,12 +158,17 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { ...@@ -202,12 +158,17 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") { if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer, par_op->InferLayout({T.target,
false, T.buffer_remap}, T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree); InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout()); par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop); auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) { if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop); vectorized_thread_loop);
...@@ -215,17 +176,22 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -215,17 +176,22 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop; return vectorized_thread_loop;
} else if (dst.scope() == "local") { } else if (dst.scope() == "local") {
auto init_loop = MakeSIMTLoop(analyzer); auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop = VectorizeLoop(init_loop); auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
return vectorized_thread_loop; return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" || } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") { dst.scope() == "global") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer, par_op->InferLayout({T.target,
false, T.buffer_remap}, T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree); InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout()); par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop); auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) { if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop); vectorized_thread_loop);
...@@ -253,7 +219,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, ...@@ -253,7 +219,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(Fill, fill) TIR_REGISTER_TL_TILE_OP(Fill, fill)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -45,7 +45,7 @@ private: ...@@ -45,7 +45,7 @@ private:
class Fill : public TileOperator { class Fill : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Fill(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <tvm/tir/op_attr_types.h> #include <tvm/tir/op_attr_types.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -29,12 +30,14 @@ using namespace tir; ...@@ -29,12 +30,14 @@ using namespace tir;
* @param args TL operator arguments: expects at least two elements where * @param args TL operator arguments: expects at least two elements where
* `args[0]` is an access pointer identifying the reducer variable * `args[0]` is an access pointer identifying the reducer variable
* and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min). * and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min).
* @param vmap Mapping from variables to Buffers used to look up the reducer
* Buffer.
*/ */
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) { FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args) {
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(); auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])]; // Normalize any supported region expression
// (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the
// underlying Buffer as reducer.
auto region = NormalizeToBufferRegion(args[0]);
node->reducer = region->buffer;
node->op = (ReducerOpType)*as_const_int(args[1]); node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -156,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const { ...@@ -156,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const {
return TileOperator(node); return TileOperator(node);
} }
TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator { ...@@ -48,7 +48,7 @@ class FinalizeReducerOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode); FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL FinalizeReducerOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -41,106 +41,21 @@ using namespace tir; ...@@ -41,106 +41,21 @@ using namespace tir;
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)] * (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
* *
* @note If `kPack` is provided it must be 1; otherwise the constructor * @note If `kPack` is provided it must be 1; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is * fails with an ICHECK (runtime assertion). No other validation is
* performed here. * performed here.
*/ */
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements) // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle Gemm::Gemm(Array<PrimExpr> args) {
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>(); ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer; node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer; node->b_ = node->bRegion_->buffer;
...@@ -165,11 +80,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) { ...@@ -165,11 +80,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) { if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value; node->wgWait_ = args[15].as<IntImm>().value()->value;
} }
node->mbarPtr_ = args[16]; if (args.size() > 16) {
if (node->mbarPtr_.as<CallNode>()) { if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; node->mbarRegion_ =
} else { NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = std::nullopt; node->mbar_ = node->mbarRegion_->buffer;
} else {
node->mbar_ = std::nullopt;
}
} }
node->cCoords_ = Array<PrimExpr>( node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()}); {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
...@@ -443,13 +361,7 @@ bool GemmNode::checkWgmma() const { ...@@ -443,13 +361,7 @@ bool GemmNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) { if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0; return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -462,13 +374,7 @@ bool GemmNode::checkWgmma() const { ...@@ -462,13 +374,7 @@ bool GemmNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) && else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32)) b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0; return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -535,9 +441,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -535,9 +441,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
// Build access pointers from regions locally // Build access pointers from regions locally
PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1); PrimExpr Aptr =
PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1); MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3); PrimExpr Bptr =
MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true);
PrimExpr Cptr =
MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true);
std::stringstream ss; std::stringstream ss;
std::string op_name; std::string op_name;
...@@ -579,11 +488,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -579,11 +488,13 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_; auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
Array<PrimExpr> new_args; Array<PrimExpr> new_args;
auto mbarPtr =
MakeAccessPtrFromRegion(mbarRegion_, /*rw*/ 3, /*require_2d*/ true);
new_args.push_back(StringImm(ss.str())); new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr); new_args.push_back(Aptr);
new_args.push_back(Bptr); new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, cCoords_)); new_args.push_back(BufferLoad(C_buffer, cCoords_));
new_args.push_back(mbarPtr_); new_args.push_back(mbarPtr);
new_args.push_back(clearAccum_); new_args.push_back(clearAccum_);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
...@@ -908,7 +819,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ...@@ -908,7 +819,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(Gemm, gemm) TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack_ = 1; int kPack_ = 1;
int wgWait_ = 0; int wgWait_ = 0;
PrimExpr mbarPtr_; BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_; Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_; mutable GemmWarpPolicy policy_;
...@@ -144,7 +144,7 @@ private: ...@@ -144,7 +144,7 @@ private:
class Gemm : public TileOperator { class Gemm : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap); TVM_DLL Gemm(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -12,100 +12,17 @@ ...@@ -12,100 +12,17 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h" #include "tcgen5_meta.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
throw; // Unreachable, keeps compiler happy
}
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                        int rw_mask) {
  const Buffer &buf = region->buffer;
  const int ndim = static_cast<int>(buf->shape.size());
  ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
  // Index arithmetic is carried out in the dtype of the first shape extent.
  DataType idx_dtype = buf->shape[0].dtype();
  // Row-major strides: strides[i] = prod(shape[i+1 .. ndim-1]).
  std::vector<PrimExpr> strides(ndim);
  PrimExpr running = make_const(idx_dtype, 1);
  for (int axis = ndim - 1; axis >= 0; --axis) {
    strides[axis] = running;
    running = running * buf->shape[axis];
  }
  // Element offset of the tile start: fold in the region mins of every
  // leading (non-tile) dimension; the trailing two dims address within the
  // tile itself.
  PrimExpr offset = make_const(idx_dtype, 0);
  for (int axis = 0; axis + 2 < ndim; ++axis) {
    offset = offset + region->region[axis]->min * strides[axis];
  }
  // Accessible extent (in elements) is the area of the trailing 2D tile.
  PrimExpr extent =
      region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
  // Assemble builtin.tvm_access_ptr(ptype, data, offset, extent, rw_mask).
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  return Call(DataType::Handle(), builtin::tvm_access_ptr(),
              Array<PrimExpr>{ptype, buf->data, offset, extent,
                              IntImm(DataType::Int(32), rw_mask)});
}
/** /**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer * @brief Construct a Gemm operator from serialized TL arguments and a buffer
...@@ -128,19 +45,17 @@ static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, ...@@ -128,19 +45,17 @@ static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
* M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
* stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
* (optional) kPack (Int), (optional) wg_wait (Int)] * (optional) kPack (Int), (optional) wg_wait (Int)]
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
* *
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* fails with an ICHECK (runtime assertion). No other validation is * fails with an ICHECK (runtime assertion). No other validation is
* performed here. * performed here.
*/ */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { GemmPy::GemmPy(Array<PrimExpr> args) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>(); ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap); node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap); node->bRegion_ = NormalizeToBufferRegion(args[1]);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap); node->cRegion_ = NormalizeToBufferRegion(args[2]);
node->a_ = node->aRegion_->buffer; node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer; node->b_ = node->bRegion_->buffer;
...@@ -165,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { ...@@ -165,11 +80,12 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) { if (args.size() > 15) {
node->wgWait_ = args[15].as<IntImm>().value()->value; node->wgWait_ = args[15].as<IntImm>().value()->value;
} }
node->mbarPtr_ = args[16]; if (args.size() > 16) {
if (node->mbarPtr_.as<CallNode>()) { if (const auto *load = args[16].as<BufferLoadNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)]; node->mbarRegion_ =
} else { NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
node->mbar_ = std::nullopt; node->mbar_ = node->mbarRegion_->buffer;
}
} }
node->cCoords_ = Array<PrimExpr>( node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()}); {args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
...@@ -219,7 +135,7 @@ GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const { ...@@ -219,7 +135,7 @@ GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
return GemmInst::kMFMA; return GemmInst::kMFMA;
} else if (TargetIsVolta(target) || TargetIsAmpere(target) || } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) || TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) { TargetIsSm100(target) || TargetIsSM120(target)) {
return GemmInst::kMMA; return GemmInst::kMMA;
} else { } else {
ICHECK(0) << "Unsupported target for gemm: " << target->str(); ICHECK(0) << "Unsupported target for gemm: " << target->str();
...@@ -266,13 +182,7 @@ bool GemmPyNode::checkWgmma() const { ...@@ -266,13 +182,7 @@ bool GemmPyNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) { if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0; return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -285,13 +195,7 @@ bool GemmPyNode::checkWgmma() const { ...@@ -285,13 +195,7 @@ bool GemmPyNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) && else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32)) b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0; return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -402,7 +306,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, ...@@ -402,7 +306,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(GemmPy, gemm_py) TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
...@@ -428,6 +332,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { ...@@ -428,6 +332,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
result.push_back(Integer(meta.atom_m)); result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n)); result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k)); result.push_back(Integer(meta.atom_k));
result.push_back(Integer(meta.enable_ws));
result.push_back(Integer(meta.enable_2cta));
} }
return result; return result;
}); });
......
...@@ -29,8 +29,8 @@ public: ...@@ -29,8 +29,8 @@ public:
int strideA_, strideB_; int strideA_, strideB_;
int offsetA_, offsetB_; int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false(); PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_; BufferRegion mbarRegion_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_; Array<PrimExpr> cCoords_;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
...@@ -59,7 +59,8 @@ public: ...@@ -59,7 +59,8 @@ public:
.def_ro("offsetA", &GemmPyNode::offsetA_) .def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_) .def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_) .def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_) .def_ro("mbarRegion", &GemmPyNode::mbarRegion_)
.def_ro("mbar", &GemmPyNode::mbar_)
.def_ro("cCoords", &GemmPyNode::cCoords_) .def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_) .def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_) .def_ro("wgWait", &GemmPyNode::wgWait_)
...@@ -82,7 +83,7 @@ private: ...@@ -82,7 +83,7 @@ private:
class GemmPy : public TileOperator { class GemmPy : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmPy(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "../target/utils.h" #include "../target/utils.h"
#include "builtin.h" #include "builtin.h"
#include "gemm.h" #include "gemm.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, ...@@ -79,16 +80,19 @@ std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
* The populated GemmSPNode is stored in the instance's internal data_ pointer. * The populated GemmSPNode is stored in the instance's internal data_ pointer.
* *
* @param args Positional TL call arguments in the above order. * @param args Positional TL call arguments in the above order.
* @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
* *
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2. * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/ */
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { GemmSP::GemmSP(Array<PrimExpr> args) {
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>(); ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->a_ = vmap[GetVarFromAccessPtr(args[0])]; node->aRegion_ = NormalizeToBufferRegion(args[0]);
node->e_ = vmap[GetVarFromAccessPtr(args[1])]; node->eRegion_ = NormalizeToBufferRegion(args[1]);
node->b_ = vmap[GetVarFromAccessPtr(args[2])]; node->bRegion_ = NormalizeToBufferRegion(args[2]);
node->c_ = vmap[GetVarFromAccessPtr(args[3])]; node->cRegion_ = NormalizeToBufferRegion(args[3]);
node->a_ = node->aRegion_->buffer;
node->e_ = node->eRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[4].as<Bool>().value(); node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value(); node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value; node->m_ = args[6].as<IntImm>().value()->value;
...@@ -298,12 +302,25 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, ...@@ -298,12 +302,25 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
return results; return results;
} }
TIR_REGISTER_TL_OP(GemmSP, gemm_sp) TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); } TVM_REGISTER_OP("tl.GemmSPWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmSPWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK() {
GemmSPNode::RegisterReflection();
GemmSPWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.GemmSPWarpPolicyComputeWarpPartition",
[](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
bool use_wgmma, int bits) {
policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
return;
});
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment