atomicadd_vectorize.cc 10 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
/*!
 * \file atomicadd_vectorize.cc
 * \brief A tool to automatically vectorize atomic add
 */

#include "atomicadd_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;

/*!
 * \brief Choose the vector width for the atomic adds inside a loop nest.
 *
 * First scans every atomicadd_elem call in \p node and derives an upper bound
 * on the width from the destination element type and \p compute_capability
 * (see GetVectorizeSizeMax). The bound is the minimum over all calls: the
 * loop is strided uniformly, so a single call that cannot be vectorized
 * (wrong arity, or a destination/value operand that is not a plain
 * BufferLoad) keeps the whole loop scalar. If the bound exceeds 1, the loop
 * nest is then visited to narrow the width further (see VisitExpr_ and
 * UpdateVectorSize).
 *
 * \param node The loop nest to analyze.
 * \param compute_capability Integer compute capability, compared against
 *        thresholds such as 75 and 90 in GetVectorizeSizeMax.
 * \return Planned vector width, the dynamic flag, and the (possibly
 *         undefined) alignment condition.
 */
AtomicAddVectorizePlanResult
AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
  this->vector_size_ = 4;
  this->dynamic_ = false;
  this->condition_ = PrimExpr();

  // Smallest width permitted by any atomicadd_elem call seen so far. It stays
  // 1 when the loop contains no such call.
  int vectorize_size_max = 1;
  bool seen_atomic_add = false;
  auto restrict_width = [&](int width) {
    if (!seen_atomic_add || width < vectorize_size_max) {
      vectorize_size_max = width;
    }
    seen_atomic_add = true;
  };

  PostOrderVisit(node, [&](const ObjectRef &obj) {
    const auto *call = obj.as<CallNode>();
    if (call == nullptr || !(call->op == atomicadd_elem_op())) {
      return;
    }
    if (call->args.size() < 2) {
      // Fallback: unexpected arity
      restrict_width(1);
      DLOG(WARNING) << "[AtomicAddVectorizePlanner] atomicadd_elem_op "
                       "expects 2 args, got "
                    << call->args.size() << "; Fallback to no vectorize";
      return;
    }
    // The rewriter only emits AtomicAddx2/x4 when both the destination (arg0)
    // and the value (arg1) are plain BufferLoads. Any other operand (e.g. a
    // Cast) is emitted as a scalar AtomicAdd, which would cover only one
    // element per iteration once the loop is re-strided, so fall back to 1.
    const auto *dst_load = call->args[0].as<BufferLoadNode>();
    if (dst_load == nullptr) {
      restrict_width(1);
      DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg0 type "
                    << call->args[0]->GetTypeKey()
                    << "; Fallback to no vectorize";
      return;
    }
    if (call->args[1].as<BufferLoadNode>() == nullptr) {
      restrict_width(1);
      DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
                    << call->args[1]->GetTypeKey()
                    << "; Fallback to no vectorize";
      return;
    }
    restrict_width(GetVectorizeSizeMax(compute_capability, dst_load->dtype));
  });

  if (vectorize_size_max <= 1) {
    return {1, dynamic_, condition_};
  }

  // Narrow further using loop extent, buffer shapes and index contiguity.
  this->max_vector_size = vectorize_size_max;
  this->operator()(node);
  return {vector_size_, dynamic_, condition_};
}

void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
  inner_for_ = node;
  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
  if (node->op == atomicadd_elem_op() && !node->args.empty()) {
    if (node->args.size() < 2) {
      return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
    }
    const BufferLoadNode *buffer_load_dst = node->args[0].as<BufferLoadNode>();
    const BufferLoadNode *buffer_load_src = node->args[1].as<BufferLoadNode>();
    if (buffer_load_src && buffer_load_src->buffer.defined() &&
        buffer_load_dst && buffer_load_dst->buffer.defined()) {
      Buffer dst_buffer = buffer_load_dst->buffer;
      UpdateVectorSize(buffer_load_dst->indices, dst_buffer);

      Buffer src_buffer = buffer_load_src->buffer;
      UpdateVectorSize(buffer_load_src->indices, src_buffer);
    }
  }
  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}

/*!
 * \brief Widest packed atomic-add width allowed for an element type.
 *
 * \param compute_capability Integer compute capability of the target.
 * \param dtype Element type of the atomic-add destination.
 * \return 2 for float16; 2 for bfloat16 when compute_capability > 75;
 *         4 for float32 when compute_capability >= 90; otherwise 1.
 */
int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
                                                   DataType dtype) {
  int width = 1;
  if (dtype == DataType::Float(16)) {
    width = 2;
  } else if (dtype == DataType::BFloat(16)) {
    width = (compute_capability > 75) ? 2 : 1;
  } else if (dtype == DataType::Float(32)) {
    width = (compute_capability >= 90) ? 4 : 1;
  }
  return width;
}

/*!
 * \brief Narrow the planned vector width using one buffer access.
 *
 * Works on `inner_for_` (the last For entered by VisitStmt_). It returns
 * without changes if no loop has been entered or the loop extent is not an
 * IntImm. Otherwise, `max_vector_size` is first reduced to a common divisor
 * with the loop extent. Then:
 *  - If the buffer's last dimension is an IntImm (static path),
 *    `max_vector_size` is further constrained by that dimension.
 *    `vector_size_` is reduced to a common divisor with it and then halved
 *    until IndiceCanVectorize accepts the flattened access offset.
 *  - Otherwise (symbolic last dimension), the plan is marked dynamic and the
 *    condition `truncmod(offset, vector_size_) == 0` is recorded.
 *    `vector_size_` itself is left unchanged on this path.
 *
 * \param indices Indices of the access into \p buffer.
 * \param buffer  The accessed buffer. NOTE(review): `shape.back()` requires
 *                a non-empty shape; 0-d buffers are not handled here.
 */
void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
                                                 const Buffer &buffer) {
  if (!inner_for_)
    return;
  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
  if (!extent_ptr)
    return;

  // NOTE(review): access_type is not used below.
  const DataType &access_type = buffer->dtype;
  // The width must divide the constant loop extent.
  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);

  // Modular-set analysis (coeff/base) of the last dimension's size.
  auto last_dim = buffer->shape.back();
  auto mod_set = analyzer_.modular_set(last_dim);

  if (buffer->shape.back().as<IntImmNode>()) {
    // Static path: the last dimension is a compile-time constant.
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);

    // Adopt gcd_base only when it is strictly smaller than the last dim size.
    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
      max_vector_size = gcd_base;
    }

    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

    // Flatten the indices into an element offset using only buffer->shape
    // (row-major). NOTE(review): buffer->strides and elem_offset are not
    // consulted, so this assumes a compact layout.
    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
      elem_offset = elem_offset + indices[i] * stride;
      stride = stride * buffer->shape[i];
    }

    // Halve the width until the offset is vectorizable along the loop var.
    // NOTE(review): termination relies on IndiceCanVectorize accepting a
    // width of 1; confirm against its definition.
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
    }
  } else if (vector_size_ <= 4) {
    // Symbolic last dimension: record a runtime alignment condition instead.
    // Plan() seeds vector_size_ with 4 and this function only lowers it, so
    // this guard appears to always hold; confirm before relying on it.
    dynamic_ = true;
    PrimExpr offset = buffer.OffsetOf(indices).back();
    condition_ = (truncmod(offset, vector_size_) == 0);
  }
}

/*!
 * \brief Applies an AtomicAddVectorizePlanResult to a loop nest.
 *
 * Every atomicadd_elem call is lowered to a `call_extern` of "AtomicAdd",
 * "AtomicAddx2" or "AtomicAddx4". All other calls are mutated normally. When
 * the plan is static and `vector_size_ > 1`, each innermost loop is also
 * re-strided so that one iteration covers `vector_size_` consecutive values
 * of the original loop variable.
 */
class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
  /*!
   * \param plan Result of AtomicAddVectorizePlanner::Plan. `plan.condition`
   *        is stored but not consumed by this rewriter.
   */
  explicit AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}

private:
  /*!
   * \brief Lower the loop body and re-stride innermost loops.
   *
   * The body is mutated first, which lowers its atomicadd_elem calls. A loop
   * counts as innermost when no nested For was entered while visiting its
   * body. For a static plan (`dynamic_ == false`) with `vector_size_ > 1`,
   * such a loop `for (i, 0, extent)` becomes
   * `for (i', 0, extent / vector_size_)`, with every use of `i` in the body
   * replaced by `i' * vector_size_`. The new loop variable keeps the
   * original name and dtype.
   *
   * ICHECKs (fatal on violation) that the extent is a constant multiple of
   * `vector_size_` and that the minimum is zero. Dynamic plans keep the loop
   * unchanged and skip these checks, because no re-striding happens.
   *
   * \return The original or re-strided For statement.
   */
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    Stmt ret = StmtExprMutator::VisitStmt_(node);
    // A nested For would have overwritten inner_for_, so equality means
    // `node` is innermost. Dynamic plans emit scalar atomics, so the loop
    // must keep its unit stride.
    if (vector_size_ == 1 || dynamic_ || inner_for_ != node) {
      return ret;
    }

    For fnode = ret.as<For>().value();
    const Var old_var = fnode->loop_var;
    // Keep the original dtype so an int64 loop does not get an int32 index.
    const Var new_var(old_var->name_hint, old_var->dtype);
    const int64_t *extent_ptr = as_const_int(fnode->extent);
    ICHECK(extent_ptr) << fnode->extent;
    const int64_t extent = *extent_ptr;
    ICHECK(extent % vector_size_ == 0)
        << "extent: " << extent << " vector_size_: " << vector_size_;
    ICHECK(is_zero(fnode->min));

    Map<Var, PrimExpr> vmap;
    vmap.Set(old_var, new_var * vector_size_);
    Stmt body = Substitute(fnode->body, vmap);
    return For(new_var, make_zero(old_var->dtype),
               make_const(old_var->dtype, extent / vector_size_), fnode->kind,
               body, fnode->thread_binding, fnode->annotations, fnode->span);
  }

  /*!
   * \brief Lower atomicadd_elem calls to extern AtomicAdd variants.
   *
   * Non-atomicadd_elem calls are passed to the base mutator unchanged in
   * kind. An atomicadd_elem call becomes:
   *  - "AtomicAddx4"/"AtomicAddx2"(address_of(dst), address_of(value), order)
   *    when the plan is static, both operands are BufferLoads and
   *    `vector_size_` is 4/2;
   *  - "AtomicAdd"(dst, value, order) for the same operands with any other
   *    width;
   *  - "AtomicAdd"(original args...) otherwise.
   * The operands of an atomicadd_elem call are not themselves re-visited.
   */
  PrimExpr VisitExpr_(const CallNode *node) final {
    if (!(node->op == atomicadd_elem_op())) {
      return StmtExprMutator::VisitExpr_(node);
    }

    const bool legal_vectorize =
        !dynamic_ && node->args.size() >= 2 &&
        node->args[0].as<BufferLoadNode>() != nullptr &&
        node->args[1].as<BufferLoadNode>() != nullptr;
    if (!legal_vectorize) {
      // Scalar fallback: forward the original operands unchanged.
      Array<PrimExpr> new_args;
      new_args.push_back(StringImm("AtomicAdd"));
      for (const PrimExpr &arg : node->args) {
        new_args.push_back(arg);
      }
      return Call(node->dtype, builtin::call_extern(), new_args);
    }

    const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
    const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
    // The default memory order is relaxed
    // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
    // IntImm(0) would select the ObjectPtr(nullptr) constructor and yield a
    // null expression, so the dtype is spelled out explicitly.
    const IntImm memory_order = node->args.size() >= 3
                                    ? Downcast<IntImm>(node->args[2])
                                    : IntImm(DataType::Int(32), 0);

    Array<PrimExpr> new_args;
    if (vector_size_ == 4) {
      new_args.push_back(StringImm("AtomicAddx4"));
      new_args.push_back(
          Call(DataType::Handle(), builtin::address_of(), {dst_node}));
      new_args.push_back(
          Call(DataType::Handle(), builtin::address_of(), {value_node}));
    } else if (vector_size_ == 2) {
      new_args.push_back(StringImm("AtomicAddx2"));
      new_args.push_back(
          Call(DataType::Handle(), builtin::address_of(), {dst_node}));
      new_args.push_back(
          Call(DataType::Handle(), builtin::address_of(), {value_node}));
    } else {
      new_args.push_back(StringImm("AtomicAdd"));
      new_args.push_back(dst_node);
      new_args.push_back(value_node);
    }
    new_args.push_back(memory_order);
    return Call(node->dtype, builtin::call_extern(), new_args);
  }

  /*! \brief Last For entered; equals the current node only when innermost. */
  const ForNode *inner_for_{nullptr};
  /*! \brief Planned width: 4 -> AtomicAddx4, 2 -> AtomicAddx2, else scalar. */
  const int vector_size_;
  /*! \brief Planned alignment condition; not read by this rewriter. */
  const PrimExpr condition_;
  /*! \brief Planner's dynamic flag; when set, calls stay scalar and loops are
   *  not re-strided. */
  const bool dynamic_;
};

/*!
 * \brief Plan and apply atomic-add vectorization to a loop nest.
 *
 * \param for_node Loop nest containing atomicadd_elem calls.
 * \param compute_capability Target compute capability forwarded to the
 *        planner.
 * \return The rewritten loop.
 */
For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
  AtomicAddVectorizePlanner planner;
  const AtomicAddVectorizePlanResult plan =
      planner.Plan(for_node, compute_capability);
  AtomicAddVectorizeRewriter rewriter(plan);
  return Downcast<For>(rewriter(for_node));
}

} // namespace tl
} // namespace tvm