/*!
 * \file tl/op/atomic_add.h
 * \brief Atomic addition operations for concurrent memory updates
 */

#ifndef TVM_TL_OP_ATOMIC_ADD_H_
#define TVM_TL_OP_ATOMIC_ADD_H_

#include <utility>

#include "operator.h"
#include "parallel.h"

namespace tvm {
namespace tl {

using namespace tir;

17
/// Node class for atomic addition operations
18
class AtomicAddNode : public TileOperatorNode {
19
public:
20
21
22
  Buffer src, dst; ///< Source and destination buffers
  Array<Range> src_range,
      dst_range;          ///< Access ranges for source and destination
23
  IntImm use_tma;         ///< Whether to use TMA for memory operations
24
  IntImm coalesced_width; ///< Width for memory coalescing optimization
25
  IntImm memory_order;    ///< Memory order for atomic operations
26

27
  mutable ParallelOp par_op_; ///< Associated parallel operation
28
29
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
                                    TileOperatorNode);
30
31
32
33
34
35

  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;

  static const Op &Get();
  TileOperator Clone() const;
36

37
38
39
40
41
42
43
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<AtomicAddNode>()
        .def_ro("src", &AtomicAddNode::src)
        .def_ro("dst", &AtomicAddNode::dst)
        .def_ro("src_range", &AtomicAddNode::src_range)
        .def_ro("dst_range", &AtomicAddNode::dst_range)
44
        .def_ro("use_tma", &AtomicAddNode::use_tma)
45
46
        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width)
        .def_ro("memory_order", &AtomicAddNode::memory_order);
47
48
  }

49
protected:
50
  /// Create SIMT-style parallel loop structure
51
  For MakeSIMTLoop(arith::Analyzer *analyzer) const;
52
  /// Generate iteration variables for loop nest
53
  Array<IterVar> MakeIterVars() const;
54
  /// Generate buffer indices from iteration variables
55
  Array<PrimExpr> MakeIndices(const Array<IterVar> &ivs, int src_dst) const;
56
57
  /// Return buffer indices and size
  std::pair<Array<PrimExpr>, PrimExpr> ReturnIndicesAndSize(int src_dst) const;
58
  /// Create boundary predicate for memory safety
59
60
  PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                         Array<PrimExpr> extents, int src_dst) const;
61
};
62

63
/// Wrapper class for atomic addition operations
64
65
class AtomicAdd : public TileOperator {
public:
66
67
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
                                             AtomicAddNode);
68
  TVM_DLL AtomicAdd(Array<PrimExpr> args);
69
  static const Op &Get();
70
71
72
73
74
};

} // namespace tl
} // namespace tvm

#endif //  TVM_TL_OP_ATOMIC_ADD_H_