region.h 3.7 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*!
 * \file tl/op/region.h
 * \brief Region tile operator (RegionOp): a buffer together with the ranges
 * accessed within it and an access mask.
 */

#ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_

#include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>

/**
 * Tile operator representing a memory region (buffer + ranges) used by TL
 * passes.
 *
 * Encapsulates the target tir::Buffer, the accessed region as an Array<Range>
 * (each Range holds a min and an extent), and an access mask that indicates
 * permitted or intended accesses for lowering and layout inference.
 */

/**
 * Lower this RegionOp into a TIR statement representing the region access.
 *
 * @param T Lowering-time arguments (e.g., loop/build context and value
 * mappings).
 * @param analyzer Arithmetic analyzer used to simplify and reason about
 * expressions.
 * @return A tir::Stmt that implements the region access/mutation described by
 * this operator.
 */

/**
 * Infer the layout mapping for this region operator.
 *
 * Produces a LayoutMap describing how loop/axis indices map to buffer axes for
 * layout-aware scheduling and subsequent operators.
 *
 * @param T Layout inference arguments (e.g., input layouts and shapes).
 * @param level The inference detail level to use.
 * @return A LayoutMap describing inferred mappings for the operator.
 */

/**
 * Return true when this RegionOp represents the full buffer region (i.e.,
 * ranges cover the entire buffer extent).
 */

/**
 * Create a shallow copy of this operator as a TileOperator handle.
 *
 * @return A TileOperator that references a cloned RegionOpNode.
 */

/**
 * Construct a RegionOp from argument expressions and a buffer map.
 *
 * @param args Positional expressions used to instantiate the operator
 * (semantics depend on how RegionOp is invoked in TL pipelines).
 * @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used
 * during creation.
 */

/**
 * Return the global Op registration for RegionOp.
 *
 * @return Reference to the registered tvm::Op describing the RegionOp.
 */
namespace tvm {
namespace tl {

using namespace tir;

/**
 * Node holding the state of a RegionOp: a buffer, the ranges accessed within
 * it, and an access mask.
 *
 * All three fields are exposed read-only via reflection and take part in
 * structural equality and hashing.
 */
class RegionOpNode : public TileOperatorNode {
public:
  /** Buffer the region refers to. */
  Buffer buffer_;
  /**
   * Ranges of `buffer_` covered by this region; each tvm::Range carries a min
   * and an extent. NOTE(review): presumably one Range per buffer dimension --
   * confirm against the implementation (region.cc).
   */
  Array<Range> ranges_;
  /**
   * Access mask for the region. It has a default member initializer so that a
   * node not populated by the RegionOp constructor never carries an
   * indeterminate value. NOTE(review): the meaning of individual bits (e.g.
   * read/write) is defined by the implementation -- confirm in region.cc.
   */
  int access_mask_{0};

  static constexpr const char *_type_key = "tl.RegionOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode);

  /**
   * Lower this operator into a TIR statement.
   *
   * @param T Lowering-time arguments.
   * @param analyzer Arithmetic analyzer available for simplification.
   * @return The lowered tir::Stmt.
   */
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  /**
   * Infer layouts for this operator.
   *
   * @param T Layout inference arguments.
   * @param level The inference level to run at.
   * @return The inferred LayoutMap.
   */
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  /** @return The buffer this region refers to. */
  const Buffer &GetBuffer() const { return buffer_; }
  /** @return The ranges covered by this region. */
  const Array<Range> &GetRanges() const { return ranges_; }
  /** @return The raw access mask. */
  int GetAccessMask() const { return access_mask_; }
  /**
   * Whether the ranges cover the entire buffer extent. Defined in the
   * implementation file.
   */
  bool IsFullRegion() const;

  /** @return A copy of this node wrapped in a TileOperator handle. */
  TileOperator Clone() const override;

  /** Register `buffer`, `ranges` and `access_mask` as read-only reflected fields. */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<RegionOpNode>()
        .def_ro("buffer", &RegionOpNode::buffer_)
        .def_ro("ranges", &RegionOpNode::ranges_)
        .def_ro("access_mask", &RegionOpNode::access_mask_);
  }

  /** Structural equality: equal iff buffer, ranges and access mask all match. */
  bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const {
    return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) &&
           equal(access_mask_, other->access_mask_);
  }

  /** Structural hash over the same fields that SEqualReduce compares. */
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(buffer_);
    hash_reduce(ranges_);
    hash_reduce(access_mask_);
  }

  // Advertise SEqualReduce/SHashReduce above to TVM's structural equal/hash.
  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;
};

/**
 * Reference handle to a RegionOpNode.
 *
 * Adds a constructor that builds the node from call arguments, plus an
 * accessor for the globally registered Op.
 */
class RegionOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode);

  /**
   * Global Op registration for RegionOp.
   *
   * @return Reference to the registered tvm::Op.
   */
  static const Op &Get();

  /**
   * Build a RegionOp from positional argument expressions.
   *
   * @param args Positional expressions; their expected layout is defined by
   * the implementation (not visible in this header).
   * @param vmap Buffer map consulted during construction. NOTE(review):
   * presumably used to resolve buffer references in `args` -- confirm in
   * region.cc.
   */
  TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_REGION_H_