layout_reducer.h 2.15 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
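/*!
 * \file layout_reducer.h
 * \brief Declares ReducerInfo, the reducer metadata (reduction operation and
 *        representation semantics) used by TL layout transforms.
 */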

#ifndef TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_
#define TVM_TL_TRANSFORM_LAYOUT_REDUCER_H_

#include <tvm/tir/op.h>

#include "../layout/layout.h"

namespace tvm {
/**
 * Types of reduction operations supported by TL transforms.
 *
 * SUM   - arithmetic sum reduction.
 * MAX   - elementwise maximum reduction.
 * MIN   - elementwise minimum reduction.
 */

/**
 * Representation semantics for a reducer.
 *
 * ALL  - reducer collapses all elements along the reduced axes.
 * NONE - reducer does not collapse (used to represent a placeholder/no-op).
 */

/**
 * Holds metadata describing a reducer used in layout transforms.
 *
 * Contains the reduction operation (`op`) and its representation semantics
 * (`rep`).
 */

/**
 * Construct a ReducerInfoNode from textual identifiers.
 *
 * @param op_str  String identifier for the reduction operation (e.g., "sum",
 * "max", "min").
 * @param rep_str String identifier for the representation semantics (e.g.,
 * "all", "none").
 */

/**
 * Handle type for ReducerInfoNode (ObjectRef wrapper).
 *
 * Constructed from string identifiers for operation and representation.
 *
 * @param op_str  String identifier for the reduction operation (e.g., "sum",
 * "max", "min").
 * @param rep_str String identifier for the representation semantics (e.g.,
 * "all", "none").
 */

/**
 * Attribute key used to attach ReducerInfo to IR nodes or other attribute maps.
 */
namespace tl {

/*! \brief Reduction operation performed by a reducer. */
enum class ReducerOpType {
  SUM, ///< Arithmetic sum reduction (textual form "sum").
  MAX, ///< Maximum reduction (textual form "max").
  MIN, ///< Minimum reduction (textual form "min").
};

/*! \brief Representation semantics of a reducer (textual forms "all" / "none"). */
enum class ReducerRepType {
  ALL,
  NONE,
};

/*!
 * \brief Metadata describing a reducer used in layout transforms.
 *
 * Holds the reduction operation (`op`) and its representation semantics
 * (`rep`).
 */
struct ReducerInfoNode : Object {
  // Default member initializers pick the first enumerators, i.e. exactly what
  // zero-/value-initialization already yields, so value-initialized nodes are
  // unaffected; they only remove indeterminate values from default-initialized
  // ones.
  /*! \brief Reduction operation (defaults to SUM). */
  ReducerOpType op{ReducerOpType::SUM};
  /*! \brief Representation semantics (defaults to ALL). */
  ReducerRepType rep{ReducerRepType::ALL};

  ReducerInfoNode() = default;

  /*!
   * \brief Construct from textual identifiers (defined out of line).
   * \param op_str  Reduction operation identifier (e.g. "sum", "max", "min").
   * \param rep_str Representation identifier (e.g. "all", "none").
   */
  ReducerInfoNode(const String &op_str, const String &rep_str);

  static constexpr const char *_type_key = "tl.ReducerInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
};

/*!
 * \brief Handle type (ObjectRef wrapper) for ReducerInfoNode.
 */
struct ReducerInfo : ObjectRef {
  /*!
   * \brief Allocate a ReducerInfoNode from textual identifiers and wrap it.
   * \param op_str  Reduction operation identifier (e.g. "sum", "max", "min").
   * \param rep_str Representation identifier (e.g. "all", "none").
   */
  TVM_DLL ReducerInfo(const String &op_str, const String &rep_str)
      : ObjectRef(make_object<ReducerInfoNode>(op_str, rep_str)) {}

  TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode);
};

namespace attr {
/*!
 * \brief Attribute key used to attach ReducerInfo to IR nodes or other
 *        attribute maps.
 */
constexpr const char *kReducerInfo{"reducer_info"};
} // namespace attr

} // namespace tl
} // namespace tvm

#endif