/*!
 * \file tl/op/reduce.h
 * \brief Reduction (ReduceOp) and cumulative-sum (CumSumOp) tile operators
 *        for tensor computations.
 */

#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_

#include <cstdint>
#include <string>

#include "operator.h"

namespace tvm {

namespace tl {

using namespace tir;

/// Supported reduction operation types.
///
/// The integer value of each enumerator is what ReduceTypeNode::type stores
/// (and what is exposed through reflection as "type"), so the values are
/// spelled out explicitly to keep them stable. Append new kinds at the end
/// instead of reordering existing ones.
enum class ReduceTypeEnum : uint8_t {
  kSum = 0,    ///< Sum reduction
  kAbsSum = 1, ///< Absolute sum reduction
  kMax = 2,    ///< Maximum value reduction
  kMin = 3,    ///< Minimum value reduction
  kAbsMax = 4, ///< Maximum absolute value reduction
  kBitAnd = 5, ///< Bitwise and reduction
  kBitOr = 6,  ///< Bitwise or reduction
  kBitXor = 7, ///< Bitwise xor reduction
};

/// Node class representing a reduction type
class ReduceTypeNode : public Object {
public:
  int type{-1}; ///< Internal type identifier
33
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object);
34
35
36
37
38
39
40
41
42
43
44
45

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
  }

  /// Type checking methods
  bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
  bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
  bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
  bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
  bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
46
47
48
  bool isBitAnd() const { return type == int(ReduceTypeEnum::kBitAnd); }
  bool isBitOr() const { return type == int(ReduceTypeEnum::kBitOr); }
  bool isBitXor() const { return type == int(ReduceTypeEnum::kBitXor); }
49
50
51
52
53
};

/// Wrapper class for reduction type with string-based construction
class ReduceType : public ObjectRef {
public:
54
55
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef,
                                             ReduceTypeNode);
56
  TVM_DLL ReduceType(std::string type) {
57
    auto node = tvm::ffi::make_object<ReduceTypeNode>();
58
59
60
61
62
63
64
65
66
67
    if (type == "sum") {
      node->type = int(ReduceTypeEnum::kSum);
    } else if (type == "abssum") {
      node->type = int(ReduceTypeEnum::kAbsSum);
    } else if (type == "max") {
      node->type = int(ReduceTypeEnum::kMax);
    } else if (type == "absmax") {
      node->type = int(ReduceTypeEnum::kAbsMax);
    } else if (type == "min") {
      node->type = int(ReduceTypeEnum::kMin);
68
69
70
71
72
73
    } else if (type == "bitand") {
      node->type = int(ReduceTypeEnum::kBitAnd);
    } else if (type == "bitor") {
      node->type = int(ReduceTypeEnum::kBitOr);
    } else if (type == "bitxor") {
      node->type = int(ReduceTypeEnum::kBitXor);
74
75
76
77
78
    } else {
      LOG(FATAL) << "Invalid reduce type: " << type;
    }
    data_ = std::move(node);
  }
79
};
80

81
/// Node class for reduction operations
82
83
class ReduceOpNode : public TileOperatorNode {
public:
84
  tir::Buffer src, dst; ///< Source and destination buffers
85
86
87
88
89
  // Optional: keep the original regions used to construct this op
  BufferRegion srcRegion_, dstRegion_;
  int dim;         ///< Dimension to reduce along
  ReduceType type; ///< Type of reduction operation
  bool clear;      ///< Whether to clear destination before reduction
90

91
92
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode,
                                    TileOperatorNode);
93

94
95
96
97
98
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceOpNode>()
        .def_ro("src", &ReduceOpNode::src)
        .def_ro("dst", &ReduceOpNode::dst)
99
100
        .def_ro("srcRegion", &ReduceOpNode::srcRegion_)
        .def_ro("dstRegion", &ReduceOpNode::dstRegion_)
101
102
103
104
105
106
        .def_ro("dim", &ReduceOpNode::dim)
        .def_ro("type", &ReduceOpNode::type)
        .def_ro("clear", &ReduceOpNode::clear);
  }

  /// Lower the operator to TIR statements
107
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
108
  /// Infer memory layout for buffers
109
110
111
112
113
114
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;

private:
115
  /// Generate initial value for reduction
116
  PrimExpr MakeInitValue() const;
117
  /// Generate reduction expression
118
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
119
  /// Generate codegen reducer string
120
121
122
  std::string MakeCodegenReducer() const;
};

123
/// Wrapper class for reduction operations
124
class ReduceOp : public TileOperator {
125
public:
126
127
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
                                             ReduceOpNode);
128
  TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
129
  static const Op &Get();
130
};
131

132
/// Node class for cumulative sum operations
133
134
class CumSumOpNode : public TileOperatorNode {
public:
135
  tir::Buffer src, dst; ///< Source and destination buffers
136
137
138
139
  // Optional: keep the original regions used to construct this op
  BufferRegion srcRegion_, dstRegion_;
  int dim;      ///< Dimension along which to compute cumulative sum
  bool reverse; ///< Whether to compute in reverse order
140
141
142
143
144
145
146
147
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
                                    TileOperatorNode);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<CumSumOpNode>()
        .def_ro("src", &CumSumOpNode::src)
        .def_ro("dst", &CumSumOpNode::dst)
148
149
        .def_ro("srcRegion", &CumSumOpNode::srcRegion_)
        .def_ro("dstRegion", &CumSumOpNode::dstRegion_)
150
151
152
        .def_ro("dim", &CumSumOpNode::dim)
        .def_ro("reverse", &CumSumOpNode::reverse);
  }
153
154
155
156
157
158
159
160

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;
};

161
/// Wrapper class for cumulative sum operations
162
163
class CumSumOp : public TileOperator {
public:
164
165
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
                                             CumSumOpNode);
166
167
  TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
168
169
};

170
171
} // namespace tl
} // namespace tvm
172

173
#endif //  TVM_TL_OP_REDUCE_H_