/*!
 * \file tl/op/reduce.h
 * \brief Reduction operators for tensor computations
 */

#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_

#include "operator.h"

namespace tvm {
namespace tl {

using namespace tir;

/// Supported reduction operation types.
///
/// ReduceTypeNode stores the chosen operation as the `int` value of one of
/// these enumerators. Inserting a new enumerator in the middle would renumber
/// every enumerator after it, so new kinds should be appended at the end.
enum class ReduceTypeEnum : uint8_t {
  kSum,    ///< Sum reduction
  kAbsSum, ///< Absolute sum reduction
  kMax,    ///< Maximum value reduction
  kMin,    ///< Minimum value reduction
  kAbsMax, ///< Maximum absolute value reduction
  kBitAnd, ///< Bitwise and reduction
  kBitOr,  ///< Bitwise or reduction
  kBitXor, ///< Bitwise xor reduction
};

/// Node class representing a reduction type
class ReduceTypeNode : public Object {
public:
  int type{-1}; ///< Internal type identifier
  static constexpr const char *_type_key = "tl.ReduceType";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
  }

  bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
    return equal(type, other->type);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }

  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

  /// Type checking methods
  bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
  bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
  bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
  bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
  bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
56
57
58
  bool isBitAnd() const { return type == int(ReduceTypeEnum::kBitAnd); }
  bool isBitOr() const { return type == int(ReduceTypeEnum::kBitOr); }
  bool isBitXor() const { return type == int(ReduceTypeEnum::kBitXor); }
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
};

/// Wrapper class for reduction type with string-based construction.
class ReduceType : public ObjectRef {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);

  /// Construct a reduction type from its name.
  ///
  /// Accepted names: "sum", "abssum", "max", "absmax", "min", "bitand",
  /// "bitor", "bitxor". Any other name aborts via LOG(FATAL).
  TVM_DLL ReduceType(const std::string &type) {
    auto node = make_object<ReduceTypeNode>();
    if (type == "sum") {
      node->type = static_cast<int>(ReduceTypeEnum::kSum);
    } else if (type == "abssum") {
      node->type = static_cast<int>(ReduceTypeEnum::kAbsSum);
    } else if (type == "max") {
      node->type = static_cast<int>(ReduceTypeEnum::kMax);
    } else if (type == "absmax") {
      node->type = static_cast<int>(ReduceTypeEnum::kAbsMax);
    } else if (type == "min") {
      node->type = static_cast<int>(ReduceTypeEnum::kMin);
    } else if (type == "bitand") {
      node->type = static_cast<int>(ReduceTypeEnum::kBitAnd);
    } else if (type == "bitor") {
      node->type = static_cast<int>(ReduceTypeEnum::kBitOr);
    } else if (type == "bitxor") {
      node->type = static_cast<int>(ReduceTypeEnum::kBitXor);
    } else {
      LOG(FATAL) << "Invalid reduce type: " << type;
    }
    data_ = std::move(node);
  }
};

/// Node class for reduction operations
91
92
class ReduceOpNode : public TileOperatorNode {
public:
93
94
95
96
  tir::Buffer src, dst; ///< Source and destination buffers
  int dim;              ///< Dimension to reduce along
  ReduceType type;      ///< Type of reduction operation
  bool clear;           ///< Whether to clear destination before reduction
97

98
99
100
  static constexpr const char *_type_key = "tl.ReduceOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceOpNode>()
        .def_ro("src", &ReduceOpNode::src)
        .def_ro("dst", &ReduceOpNode::dst)
        .def_ro("dim", &ReduceOpNode::dim)
        .def_ro("type", &ReduceOpNode::type)
        .def_ro("clear", &ReduceOpNode::clear);
  }

  bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
    return equal(src, other->src) && equal(dst, other->dst) &&
           equal(dim, other->dim) && equal(type, other->type) &&
           equal(clear, other->clear);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(src);
    hash_reduce(dst);
    hash_reduce(dim);
    hash_reduce(type);
    hash_reduce(clear);
  }

  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

  /// Lower the operator to TIR statements
129
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
130
  /// Infer memory layout for buffers
131
132
133
134
135
136
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;

private:
137
  /// Generate initial value for reduction
138
  PrimExpr MakeInitValue() const;
139
  /// Generate reduction expression
140
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
141
  /// Generate codegen reducer string
142
143
144
  std::string MakeCodegenReducer() const;
};

145
/// Wrapper class for reduction operations
146
class ReduceOp : public TileOperator {
147
public:
148
149
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
  TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
150
  static const Op &Get();
151
};
152

153
/// Node class for cumulative sum operations
154
155
class CumSumOpNode : public TileOperatorNode {
public:
156
157
158
  tir::Buffer src, dst; ///< Source and destination buffers
  int dim;              ///< Dimension along which to compute cumulative sum
  bool reverse;         ///< Whether to compute in reverse order
159
160
161
162
163
164
165
166
167
168
  static constexpr const char *_type_key = "tl.CumSumOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;
};

169
/// Wrapper class for cumulative sum operations
170
171
172
173
174
class CumSumOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
  TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
175
176
};

177
178
} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_REDUCE_H_