// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.

/*!
 * \file src/op/finalize_reducer.h
 * \brief Define finalize_reducer operator.
 */

#ifndef TVM_TL_OP_FINALIZE_REDUCER_H_
#define TVM_TL_OP_FINALIZE_REDUCER_H_

#include "../transform/layout_reducer.h"
#include "./operator.h"

/**
 * FinalizeReducer operator node for Tile IR.
 *
 * Represents a TL-level operator that finalizes a reducer buffer into a
 * result using a specified reducer operation.
 *
 * Public members:
 * - reducer: the tir::Buffer that holds the intermediate reduction values.
 * - op: the reducer operation to apply when finalizing values.
 */

/**
 * Lower this operator to a TIR statement.
 *
 * @param T Lowering arguments (buffers, indices, and other lowering context).
 * @param analyzer Arithmetic analyzer used to simplify expressions during
 * lowering.
 * @return A tir::Stmt that implements the finalize-reducer semantics for the
 * provided lowering context.
 */

/**
 * Infer layout mapping for this operator.
 *
 * Determines how input and output buffer layouts relate for the
 * finalize-reducer operator at the given inference level.
 *
 * @param T Layout inference arguments (including operand layouts and shapes).
 * @param level The InferLevel at which layout inference is being performed.
 * @return A LayoutMap describing the inferred layouts.
 */

/**
 * Get the singleton Op object representing this operator.
 *
 * @return A reference to the Op describing FinalizeReducer.
 */

/**
 * Create a deep copy of this operator node as a TileOperator.
 *
 * @return A TileOperator handle that is an independent clone of this node.
 */

/**
 * Public wrapper for FinalizeReducerOpNode.
 *
 * Provides the reference semantics and construction API used by callers.
 */

/**
 * Construct a FinalizeReducerOp from TL-level arguments.
 *
 * @param args Positional call arguments that parameterize the operator; their
 *             exact meaning is defined by the implementation that parses them.
 * @param vmap Buffer map (BufferMap) used to resolve the operator's buffer
 *             operands to tir::Buffer instances.
 */

/**
 * Get the Op singleton for the public FinalizeReducerOp handle.
 *
 * @return A reference to the Op describing FinalizeReducer.
 */
namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Node for the TL-level finalize-reducer tile operator.
 *
 * Holds the reducer buffer and the reduction operation that is applied when
 * the operator finalizes the buffer's values.
 */
class FinalizeReducerOpNode : public TileOperatorNode {
public:
  /*! \brief Buffer holding the intermediate reduction values. */
  tir::Buffer reducer;
  /*! \brief Reduction operation applied when finalizing the values. */
  ReducerOpType op;

  static constexpr const char *_type_key = "tl.FinalizeReducerOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);

  /*!
   * \brief Lower this operator to a TIR statement.
   * \param T Lowering arguments (lowering context for this call site).
   * \param analyzer Arithmetic analyzer used to simplify expressions during
   *        lowering.
   * \return A statement implementing the finalize-reducer semantics.
   */
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;

  /*!
   * \brief Infer the layout mapping for this operator.
   * \param T Layout inference arguments.
   * \param level The InferLevel at which inference is being performed.
   * \return The inferred LayoutMap.
   */
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;

  /*! \brief Get the Op singleton describing this operator. */
  static const Op &Get();

  /*!
   * \brief Create an independent copy of this node wrapped as a TileOperator.
   *
   * Marked `override` like Lower/InferLayout, so the compiler checks that the
   * signature matches the TileOperatorNode virtual.
   */
  TileOperator Clone() const override;
};

/*!
 * \brief Reference (handle) type for FinalizeReducerOpNode.
 *
 * Provides the reference semantics and the construction API used by callers.
 */
class FinalizeReducerOp : public TileOperator {
public:
  TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
                                FinalizeReducerOpNode);

  /*! \brief Get the Op singleton describing FinalizeReducer. */
  static const Op &Get();

  /*!
   * \brief Build the operator from TL-level call arguments.
   * \param args Positional call arguments; their meaning is defined by the
   *        implementation that parses them.
   * \param vmap Buffer map used to resolve buffer operands.
   */
  TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
};

} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_FINALIZE_REDUCER_H_