"vscode:/vscode.git/clone" did not exist on "e5fb9c518b6aa0838391004eeb122cdee34b60c4"
atomicadd_vectorize.cc 9.83 KB
Newer Older
1
2
/*!
 * \file atomicadd_vectorize.cc
 * \brief A tool to automatically vectorize atomic add operations.
 */

#include "atomicadd_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;

AtomicAddVectorizePlanResult
AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
  int vectorize_size_max = 1;
  this->vector_size_ = 4;
  this->dynamic_ = false;
  this->condition_ = PrimExpr();

  PostOrderVisit(node, [&](const ObjectRef &obj) {
    if (const auto *call = obj.as<CallNode>()) {
      if (call->op == atomicadd_elem_op()) {
        if (call->args.size() < 2) {
          // Fallback: unexpected arity
          vectorize_size_max = 1;
          DLOG(WARNING) << "[AtomicAddVectorizePlanner] atomicadd_elem_op "
                           "expects at least 2 args, got "
                        << call->args.size()
                        << "; falling back to no vectorization";
          return;
        }
        DataType dtype;
        if (const auto *load = call->args[0].as<BufferLoadNode>()) {
          dtype = load->dtype;
          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
        } else if (const auto *ite = call->args[0].as<IfThenElseNode>()) {
          if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
            dtype = then_load->dtype;
            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
          } else if (const auto *else_load =
                         ite->else_case.as<BufferLoadNode>()) {
            dtype = else_load->dtype;
            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
          } else {
            // fallback
            vectorize_size_max = 1;
            DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case "
                             "has no BufferLoad; falling back to no "
                             "vectorization";
          }
        } else {
          // fallback
          vectorize_size_max = 1;
          DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
                        << call->args[0]->GetTypeKey()
                        << "; falling back to no vectorization";
        }
      }
    }
  });

  if (vectorize_size_max <= 1) {
    return {1, dynamic_, condition_};
  }

  this->max_vector_size = vectorize_size_max;
  this->operator()(node);
  return {vector_size_, dynamic_, condition_};
}
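// Example (illustrative): for a float16 atomic add, Plan() caps the vector
// size at 2, so the returned vector_size is at most 2. If no suitable
// atomicadd_elem call is found, Plan() returns early with a vector size of 1
// and the rewriter below keeps the loop structure, lowering each
// atomicadd_elem to a scalar "AtomicAdd" extern call.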

void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
  inner_for_ = node;
  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
  if (node->op == atomicadd_elem_op() && !node->args.empty()) {
    if (node->args.size() < 2) {
      return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
    }
    const BufferLoadNode *buffer_load_dst = node->args[0].as<BufferLoadNode>();
    const BufferLoadNode *buffer_load_src = node->args[1].as<BufferLoadNode>();
    if (buffer_load_src && buffer_load_src->buffer.defined() &&
        buffer_load_dst && buffer_load_dst->buffer.defined()) {
      Buffer dst_buffer = buffer_load_dst->buffer;
      UpdateVectorSize(buffer_load_dst->indices, dst_buffer);

      Buffer src_buffer = buffer_load_src->buffer;
      UpdateVectorSize(buffer_load_src->indices, src_buffer);
    }
  }
  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}

int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
                                                   DataType dtype) {
  if (dtype == DataType::Float(16)) {
    return 2;
  }
  if (dtype == DataType::BFloat(16)) {
    return compute_capability > 75 ? 2 : 1;
  }
  if (dtype == DataType::Float(32)) {
    return compute_capability >= 90 ? 4 : 1;
  }
  return 1;
}
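// Illustrative values implied by the rules above:
//   GetVectorizeSizeMax(90, DataType::Float(32))  == 4   // e.g. sm_90
//   GetVectorizeSizeMax(80, DataType::Float(32))  == 1
//   GetVectorizeSizeMax(80, DataType::BFloat(16)) == 2
//   GetVectorizeSizeMax(75, DataType::BFloat(16)) == 1
//   GetVectorizeSizeMax(75, DataType::Float(16))  == 2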

void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
                                                 const Buffer &buffer) {
  if (!inner_for_)
    return;
  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
  if (!extent_ptr)
    return;

  const DataType &access_type = buffer->dtype;
  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);

  auto last_dim = buffer->shape.back();
  auto mod_set = analyzer_.modular_set(last_dim);

  if (buffer->shape.back().as<IntImmNode>()) {
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);

    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
      max_vector_size = gcd_base;
    }

    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
      elem_offset = elem_offset + indices[i] * stride;
      stride = stride * buffer->shape[i];
    }

    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
    }
  } else if (vector_size_ <= 4) {
    dynamic_ = true;
    PrimExpr offset = buffer.OffsetOf(indices).back();
    condition_ = (truncmod(offset, vector_size_) == 0);
  }
}
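// Worked example (a sketch of the constant-shape path above, assuming
// ZeroAwareGCD(x, 0) == x): with max_vector_size = 4, a loop extent of 8, and
// a constant last buffer dimension of 6, GCD(4, 8) keeps max_vector_size at
// 4; the modular set of the constant 6 has coeff = 0 and base = 6, so the
// coeff GCD leaves 4 unchanged while gcd_base = GCD(4, 6) = 2; since 2 < 6,
// max_vector_size is capped at 2 and vector_size_ becomes
// GCD(2, vector_size_).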

class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
  explicit AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}

private:
  /**
   * @brief Visits a For node and rewrites the innermost loop for atomic-add
   * vectorization.
   *
   * If the visited For node is the recorded innermost loop, this method
   * validates that the loop extent is a compile-time constant, that it is
   * divisible by the planned vector size, and that the loop minimum is zero.
   * When vectorization is enabled (dynamic_ == false) it:
   *  - creates a fresh loop variable with the same name hint as the old one,
   *  - substitutes the old loop variable with `new_var * vector_size_` so
   *    each new iteration maps to a contiguous vector-sized chunk,
   *  - returns a new For with the extent divided by vector_size_ and the
   *    transformed body.
   *
   * If dynamic_ is true, the method returns the (possibly mutated) inner For
   * unchanged.
   *
   * Side effects:
   *  - updates inner_for_ to point to the current For node during visitation.
   *  - performs runtime checks (ICHECK) to enforce: constant extent, extent %
   *    vector_size_ == 0, and zero loop minimum; violations terminate
   *    execution.
   *
   * @return The original or transformed For statement as a Stmt.
   */
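  // Illustrative sketch (hypothetical TIR, vector_size_ = 4, dynamic_ =
  // false):
  //   before: for (i, 0, 16) { atomicadd_elem(dst[i], src[i]) }
  //   after:  for (i, 0, 4)  { atomicadd_elem(dst[i * 4], src[i * 4]) }
  // The atomicadd_elem call itself is then lowered to an "AtomicAddx4" extern
  // call by VisitExpr_ below.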
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (vector_size_ == 1)
      return ret;
    if (inner_for_ == node) {
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto new_var = Var(old_var->name_hint);
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Map<Var, PrimExpr> vmap;
        vmap.Set(old_var, new_var * vector_size_);
        Stmt body = Substitute(fnode->body, vmap);
        return For(new_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
    return ret;
  }

  PrimExpr VisitExpr_(const CallNode *node) final {
    bool legal_vectorize = true;
    if (dynamic_)
      legal_vectorize = false;
    if (!(node->op == atomicadd_elem_op()))
      legal_vectorize = false;
    if (node->args.size() < 2)
      legal_vectorize = false;
    if (legal_vectorize) {
      const BufferLoadNode *temp_dst_node = node->args[0].as<BufferLoadNode>();
      const BufferLoadNode *temp_value_node =
          node->args[1].as<BufferLoadNode>();
      if (!temp_dst_node || !temp_value_node)
        legal_vectorize = false;
    }
    if (legal_vectorize) {
      const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
      const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
      // The default memory order is relaxed
      // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
      const IntImm memory_order = node->args.size() >= 3
                                      ? Downcast<IntImm>(node->args[2])
                                      : IntImm(DataType::Int(32), 0);

      Call address_of_dst =
          Call(DataType::Handle(), builtin::address_of(), {dst_node});
      Call address_of_value =
          Call(DataType::Handle(), builtin::address_of(), {value_node});
      Array<PrimExpr> new_args;
      if (vector_size_ == 4) {
        new_args.push_back(StringImm("AtomicAddx4"));
      } else if (vector_size_ == 2) {
        new_args.push_back(StringImm("AtomicAddx2"));
      } else {
        new_args.push_back(StringImm("AtomicAdd"));
      }
      new_args.push_back(address_of_dst);
      new_args.push_back(address_of_value);
      new_args.push_back(memory_order);

      Call new_call =
          tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

      return new_call;
    } else {
      Array<PrimExpr> new_args;
      new_args.push_back(StringImm("AtomicAdd"));
      for (auto x : node->args)
        new_args.push_back(x);

      Call new_call =
          tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

      return new_call;
    }
  }

  const ForNode *inner_for_ = nullptr;
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
  AtomicAddVectorizePlanner planner;
  AtomicAddVectorizePlanResult res =
      planner.Plan(for_node, compute_capability);
  auto rewriter = AtomicAddVectorizeRewriter(res);
  return Downcast<For>(rewriter(for_node));
}
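// Usage sketch (illustrative): for an innermost loop containing
// atomicadd_elem calls over float32 buffers on compute capability 90, the
// planner allows a vector size of up to 4 and the rewriter emits
// "AtomicAddx4" extern calls:
//   For vectorized = VectorizeAtomicAdd(loop, /*compute_capability=*/90);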

} // namespace tl
} // namespace tvm