/*!
 * \file atomicadd_vectorize.h
 * \brief A tool to automatically vectorize a for loop containing atomic-add operations
 */

#ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
#define TVM_TL_ATOMICADD_VECTORIZE_H_

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "atomicadd_vectorize.h"
#include "common/loop_vectorization_utils.h"
#include <numeric>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Vectorize the atomic-add operations found in a loop.
 *
 * Declared here; the definition lives in the corresponding source file.
 *
 * \param for_node The loop to transform.
 * \param compute_capability Target compute capability. Presumably the CUDA SM
 *        version, e.g. 80 for sm_80; TODO confirm against the definition.
 * \return The resulting loop, possibly rewritten.
 */
For VectorizeAtomicAdd(const For &for_node, int compute_capability);

/*!
 * \brief Plan produced by AtomicAddVectorizePlanner::Plan.
 *
 * Every member has a default initializer, so a default-constructed result is
 * well-defined instead of holding indeterminate values. The type remains an
 * aggregate (C++14+), so `{vector_size, dynamic, condition}` brace
 * initialization keeps working unchanged.
 */
struct AtomicAddVectorizePlanResult {
  /*! \brief Chosen vector width; defaults to 1 (scalar, i.e. no widening). */
  int vector_size = 1;
  /*!
   * \brief Whether the plan is dynamic. Presumably true when it must be
   *        guarded by `condition` at runtime; TODO confirm in the planner.
   */
  bool dynamic = false;
  /*! \brief Condition attached to the plan; null (undefined) unless set. */
  PrimExpr condition;
};

/*!
 * \brief Walks a loop with an arithmetic analyzer and decides how far the
 *        atomic-add accesses inside it can be vectorized.
 *
 * All methods are defined out of line in the corresponding source file.
 */
class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
  AtomicAddVectorizePlanner();

  /*!
   * \brief Analyze \p node and return the resulting vectorization plan.
   * \param node The loop to analyze.
   * \param compute_capability Target compute capability. Assumed to be the
   *        CUDA SM version; TODO confirm.
   */
  AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability);

private:
  // ---- Visitor hooks (override the base analyzer visitor) ----
  void VisitStmt_(const ForNode *node) final;
  void VisitExpr_(const CallNode *node) final;

  // ---- Planning helpers ----
  // Upper bound on the vector width for the given target and element type.
  int GetVectorizeSizeMax(int compute_capability, DataType dtype);
  // Refine the current vector width from one buffer access's indices.
  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer);

  // ---- Planning state (declaration order preserved) ----
  const ForNode *inner_for_{nullptr};
  bool has_nonlocal_memory_access_{false};
  int vector_size_{4};
  int max_vector_size{1};
  bool dynamic_{false};
  PrimExpr condition_;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_ATOMICADD_VECTORIZE_H_