/*!
 * \file tl/op/region.cc
 * \brief Define the region operator (a bridge that carries a BufferRegion via
 *        Call args).
 *
 * Notes:
 * - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane
 *   count. Dynamic extents like (H1 - H0) cannot be encoded as
 *   Ramp(lanes = H1 - H0), and lowering a BufferRegion to a BufferLoad loses
 *   the explicit extent information.
 * - tl.region carries both mins and extents in Call args and lets the backend
 *   reconstruct a BufferRegion faithfully.
 */

#include "region.h"
#include <tvm/tir/op.h>

namespace tvm {
namespace tl {
using namespace tir;

21
RegionOp::RegionOp(Array<PrimExpr> args) {
22
23
24
25
26
27
28
  size_t n = args.size();
  size_t ndim = n - 2;
  auto load = args[0].as<BufferLoadNode>();
  ICHECK(load);
  ICHECK(load->indices.size() == ndim)
      << "load->indices.size() = " << load->indices << " ndim = " << ndim;
  Array<Range> ranges;
29
  // Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents
30
  for (size_t i = 0; i < ndim; i++) {
31
    PrimExpr index = load->indices[i];
32
    PrimExpr extent = args[2 + i];
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    if (const auto *ramp = index.as<RampNode>()) {
      const auto *stride_imm = ramp->stride.as<IntImmNode>();
      ICHECK(stride_imm && stride_imm->value == 1)
          << "RegionOp expects stride-1 Ramp for index";
      if (const auto *lanes_imm = ramp->lanes.as<IntImmNode>()) {
        if (const auto *ext_imm = extent.as<IntImmNode>()) {
          ICHECK_EQ(lanes_imm->value, ext_imm->value)
              << "Ramp lanes and provided extent must match";
        }
      }
      ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
    } else {
      ranges.push_back(Range::FromMinExtent(index, extent));
    }
47
  }
48
  ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
49
50
51
52
53
54
55
  node->buffer_ = load->buffer;
  node->access_mask_ = static_cast<int>(*as_const_int(args[1]));
  node->ranges_ = ranges;
  data_ = std::move(node);
}

/*!
 * \brief Return a new RegionOp backed by a copy of this node.
 *
 * The copy is made by copy-constructing a fresh RegionOpNode from *this and
 * wrapping it in a RegionOp reference.
 */
TileOperator RegionOpNode::Clone() const {
  ObjectPtr<RegionOpNode> duplicate =
      tvm::ffi::make_object<RegionOpNode>(*this);
  return RegionOp(duplicate);
}

bool RegionOpNode::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}

/*!
 * \brief Lowering hook for tl.region.
 *
 * Emits no work of its own: the result is always Evaluate(0). Neither `T`
 * nor `analyzer` is consulted.
 */
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Stmt no_op = Evaluate(0);
  return no_op;
}

/*!
 * \brief Layout inference hook for tl.region.
 *
 * Contributes no layout constraints: an empty LayoutMap is returned for every
 * InferLevel, independent of `T` and `level`.
 */
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  LayoutMap no_constraints;
  return no_constraints;
}

79
TIR_REGISTER_TL_TILE_OP(RegionOp, region)
80
81
82
83
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

84
85
TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); }

86
87
} // namespace tl
} // namespace tvm