region.cc 3.84 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 * \file tl/op/region.cc
 * \brief Define region operator.
 *
 */

#include "region.h"
#include <tvm/tir/op.h>

namespace tvm {
namespace tl {
using namespace tir;

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/**
 * @brief Construct a RegionOp from TL operator arguments.
 *
 * Parses the TL `region` operator call arguments to populate the RegionOpNode:
 * - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
 * minima.
 * - args[1] must be a constant integer used as the access mask.
 * - args[2 + i] provides the extent for dimension `i`.
 *
 * The constructor validates that the number of load indices equals `args.size()
 * - 2` and will abort via ICHECK on mismatch or if args[0] is not a
 * `BufferLoad`.
 *
 * Parameters:
 * - args: TL operator call arguments in the form
 *     [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
 * extent_{n-1}] where n = number of dimensions.
 * - vmap: BufferMap passed through by the caller (not documented here as a
 * generic utility).
 */
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
  size_t n = args.size();
  size_t ndim = n - 2;
  auto load = args[0].as<BufferLoadNode>();
  ICHECK(load);
  ICHECK(load->indices.size() == ndim)
      << "load->indices.size() = " << load->indices << " ndim = " << ndim;
  Array<Range> ranges;
  for (size_t i = 0; i < ndim; i++) {
    PrimExpr min = load->indices[i];
    PrimExpr extent = args[2 + i];
    ranges.push_back(Range::FromMinExtent(min, extent));
  }
  ObjectPtr<RegionOpNode> node = make_object<RegionOpNode>();
  node->buffer_ = load->buffer;
  node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
  node->ranges_ = ranges;
  data_ = std::move(node);
}

54
55
56
57
58
/**
 * @brief Create a copy of this RegionOpNode and return it as a TileOperator.
 *
 * @return TileOperator A new TileOperator that owns a copied RegionOpNode.
 */
59
60
61
62
63
TileOperator RegionOpNode::Clone() const {
  auto op = make_object<RegionOpNode>(*this);
  return RegionOp(op);
}

64
65
66
67
68
69
70
71
72
73
/**
 * @brief Check whether the region spans the entire underlying buffer.
 *
 * Returns true if for every dimension the range minimum is zero and the
 * range extent is structurally equal to the corresponding buffer shape
 * dimension. Otherwise returns false.
 *
 * @return true if the region covers the full buffer in all dimensions; false
 * otherwise.
 */
74
75
76
77
78
79
80
81
82
83
bool RegionOpNode::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}

84
85
86
87
88
89
90
91
92
93
94
95
/**
 * @brief Lower the region operator to a TIR statement.
 *
 * Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's
 * evaluation path (currently `Evaluate(0)`).
 *
 * @param T Lowering context (provides buffers, producers/consumers and other
 *          environment required for lowering).
 * @param analyzer Optional arithmetic analyzer used for simplification during
 *                 lowering.
 * @return Stmt The lowered TIR statement representing this region operation.
 */
96
97
98
99
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  return Evaluate(0);
}

100
101
102
103
104
105
106
107
108
109
110
/**
 * @brief Infers data layout for the region operator.
 *
 * This operator does not provide any layout inference; the function always
 * returns an empty LayoutMap regardless of the provided arguments or inference
 * level.
 *
 * @param T Layout inference arguments (ignored).
 * @param level Inference granularity level (ignored).
 * @return LayoutMap Empty map indicating no inferred layouts.
 */
111
112
113
114
115
116
117
118
119
120
121
122
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  return {};
}

// Register RegionOp as the TL operator `region` through the project's
// TIR_REGISTER_TL_OP macro.
// - set_num_inputs(-1): variable number of inputs (TVM OpRegEntry
//   convention). The RegionOp constructor parses them as
//   [BufferLoad, access_mask, extent_0, ..., extent_{n-1}].
// - TCallEffectKind is kPure: the call is registered as side-effect free.
TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

} // namespace tl
} // namespace tvm