/*!
 * \file tl/op/reduce.h
 * \brief Reduction operators for tensor computations
 */

#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_

#include "operator.h"

namespace tvm {
namespace tl {

using namespace tir;

/// Kinds of reduction understood by the tile reduction operator.
///
/// The enumerator values are spelled out explicitly because they are stored
/// as plain `int` in ReduceTypeNode::type and compared numerically there.
enum class ReduceTypeEnum : uint8_t {
  kSum = 0,    ///< Plain sum of the elements
  kAbsSum = 1, ///< Sum of absolute values
  kMax = 2,    ///< Largest element
  kMin = 3,    ///< Smallest element
  kAbsMax = 4, ///< Largest absolute value
};

/// Node class representing a reduction type
class ReduceTypeNode : public Object {
public:
  int type{-1}; ///< Internal type identifier
  static constexpr const char *_type_key = "tl.ReduceType";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceTypeNode>().def_ro("type", &ReduceTypeNode::type);
  }

  bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const {
    return equal(type, other->type);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); }

  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

  /// Type checking methods
  bool isSum() const { return type == int(ReduceTypeEnum::kSum); }
  bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); }
  bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
  bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
  bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
};

/// Object reference to a ReduceTypeNode, constructible from a textual name.
///
/// Accepted names are "sum", "abssum", "max", "absmax" and "min"; any other
/// string is reported through LOG(FATAL).
class ReduceType : public ObjectRef {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode);

  /// Build a ReduceType from its name.
  /// \param type One of "sum", "abssum", "max", "absmax", "min".
  /// Intentionally not `explicit`: callers may rely on implicit conversion
  /// from std::string.
  TVM_DLL ReduceType(const std::string &type) {
    auto node = make_object<ReduceTypeNode>();
    if (type == "sum") {
      node->type = int(ReduceTypeEnum::kSum);
    } else if (type == "abssum") {
      node->type = int(ReduceTypeEnum::kAbsSum);
    } else if (type == "max") {
      node->type = int(ReduceTypeEnum::kMax);
    } else if (type == "absmax") {
      node->type = int(ReduceTypeEnum::kAbsMax);
    } else if (type == "min") {
      node->type = int(ReduceTypeEnum::kMin);
    } else {
      LOG(FATAL) << "Invalid reduce type: " << type;
    }
    data_ = std::move(node);
  }
};

/// Node class for reduction operations
79
80
class ReduceOpNode : public TileOperatorNode {
public:
81
82
83
84
  tir::Buffer src, dst; ///< Source and destination buffers
  int dim;              ///< Dimension to reduce along
  ReduceType type;      ///< Type of reduction operation
  bool clear;           ///< Whether to clear destination before reduction
85

86
87
88
  static constexpr const char *_type_key = "tl.ReduceOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode);

89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<ReduceOpNode>()
        .def_ro("src", &ReduceOpNode::src)
        .def_ro("dst", &ReduceOpNode::dst)
        .def_ro("dim", &ReduceOpNode::dim)
        .def_ro("type", &ReduceOpNode::type)
        .def_ro("clear", &ReduceOpNode::clear);
  }

  bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const {
    return equal(src, other->src) && equal(dst, other->dst) &&
           equal(dim, other->dim) && equal(type, other->type) &&
           equal(clear, other->clear);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(src);
    hash_reduce(dst);
    hash_reduce(dim);
    hash_reduce(type);
    hash_reduce(clear);
  }

  static constexpr bool _type_has_method_sequal_reduce = true;
  static constexpr bool _type_has_method_shash_reduce = true;

  /// Lower the operator to TIR statements
117
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
118
  /// Infer memory layout for buffers
119
120
121
122
123
124
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;

private:
125
  /// Generate initial value for reduction
126
  PrimExpr MakeInitValue() const;
127
  /// Generate reduction expression
128
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
129
  /// Generate codegen reducer string
130
131
132
  std::string MakeCodegenReducer() const;
};

/// Wrapper class for reduction operations
134
class ReduceOp : public TileOperator {
135
public:
136
137
  TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode);
  TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
138
  static const Op &Get();
139
};

/// Node class for cumulative sum operations
142
143
class CumSumOpNode : public TileOperatorNode {
public:
144
145
146
  tir::Buffer src, dst; ///< Source and destination buffers
  int dim;              ///< Dimension along which to compute cumulative sum
  bool reverse;         ///< Whether to compute in reverse order
147
148
149
150
151
152
153
154
155
156
  static constexpr const char *_type_key = "tl.CumSumOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode);

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  static const Op &Get();
  TileOperator Clone() const;
};

/// Wrapper class for cumulative sum operations
158
159
160
161
162
class CumSumOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode);
  TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
163
164
};

} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_REDUCE_H_