/*!
 * \file atomicadd_vectorize.cc
 * \brief A tool to automatically vectorize atomic add operations
 */

#include "atomicadd_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;

AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;

AtomicAddVectorizePlanResult
AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
  int vectorize_size_max = 1;
  // Start from the widest supported width and shrink as constraints appear.
  this->vector_size_ = 4;
  this->dynamic_ = false;
  this->condition_ = PrimExpr();

  PostOrderVisit(node, [&](const ObjectRef &obj) {
    if (const auto *call = obj.as<CallNode>()) {
      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
        const auto *func_name = call->args[0].as<StringImmNode>();
        if (!func_name)
          return;
        if (func_name->value == "AtomicAdd") {
          DataType dtype;
          if (const auto *load = call->args[1].as<BufferLoadNode>()) {
            dtype = load->dtype;
            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
          } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
            if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
              dtype = then_load->dtype;
              vectorize_size_max =
                  GetVectorizeSizeMax(compute_capability, dtype);
            } else if (const auto *else_load =
                           ite->else_case.as<BufferLoadNode>()) {
              dtype = else_load->dtype;
              vectorize_size_max =
                  GetVectorizeSizeMax(compute_capability, dtype);
            } else {
              // Fall back to scalar atomics when no BufferLoad is found.
              vectorize_size_max = 1;
              DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case "
                               "has no BufferLoad; falling back to no "
                               "vectorization";
            }
          } else {
            // Fall back to scalar atomics for unexpected destination nodes.
            vectorize_size_max = 1;
            DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
                          << call->args[1]->GetTypeKey()
                          << "; falling back to no vectorization";
          }
        }
      }
    }
  });
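
  // Illustrative shape of the extern call the scan above matches (an
  // assumption about how the atomic add is lowered, for exposition only):
  //   call_extern("AtomicAdd", C[i], V[i])
  // args[0] names the extern function, args[1] is the destination BufferLoad
  // (possibly wrapped in a conditional for predicated accesses), and args[2]
  // is the value operand; the dtype of args[1] picks the per-architecture
  // upper bound via GetVectorizeSizeMax.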

  if (vectorize_size_max <= 1) {
    return {1, dynamic_, condition_};
  }

  this->max_vector_size = vectorize_size_max;
  this->operator()(node);
  return {vector_size_, dynamic_, condition_};
}
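
// Worked example (illustrative, not taken from a real trace): for
//   for i in [0, 8): AtomicAdd(C[i], V[i])
// over float32 buffers on SM90, GetVectorizeSizeMax returns 4; the visitor
// below then intersects that bound with the loop extent and index alignment,
// so the returned plan would be {vector_size = 4, dynamic = false,
// condition = {}}.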

void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
  inner_for_ = node;
  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}

void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
  // AtomicAdd extern calls carry (name, dst, value), so three args minimum.
  if (node->op == builtin::call_extern() && node->args.size() >= 3) {
    if (const auto *func_name = node->args[0].as<StringImmNode>()) {
      if (func_name->value == "AtomicAdd") {
        const BufferLoadNode *buffer_load_dst =
            node->args[1].as<BufferLoadNode>();
        const BufferLoadNode *buffer_load_src =
            node->args[2].as<BufferLoadNode>();
        if (buffer_load_src && buffer_load_src->buffer.defined() &&
            buffer_load_dst && buffer_load_dst->buffer.defined()) {
          // Both the destination and the source access constrain the width.
          Buffer dst_buffer = buffer_load_dst->buffer;
          UpdateVectorSize(buffer_load_dst->indices, dst_buffer);

          Buffer src_buffer = buffer_load_src->buffer;
          UpdateVectorSize(buffer_load_src->indices, src_buffer);
        }
      }
    }
  }
  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}

int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
                                                   DataType dtype) {
  // Upper bound on lanes per extern atomic call, per dtype and architecture.
  if (dtype == DataType::Float(16)) {
    return 2;
  }
  if (dtype == DataType::BFloat(16)) {
    // 2-wide bf16 atomics are only emitted for SM80 and newer.
    return compute_capability > 75 ? 2 : 1;
  }
  if (dtype == DataType::Float(32)) {
    // 4-wide fp32 atomics are only emitted for SM90 and newer.
    return compute_capability >= 90 ? 4 : 1;
  }
  return 1;
}
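
// Summary of the width table above (lanes of one extern atomic call):
//   float16  -> 2 on all supported targets
//   bfloat16 -> 2 when compute_capability > 75 (SM80+), else 1
//   float32  -> 4 when compute_capability >= 90 (SM90+), else 1
//   other    -> 1 (plain scalar AtomicAdd)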

void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
                                                 const Buffer &buffer) {
  if (!inner_for_)
    return;
  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
  if (!extent_ptr)
    return;

  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);

  auto last_dim = buffer->shape.back();
  auto mod_set = analyzer_.modular_set(last_dim);

  if (buffer->shape.back().as<IntImmNode>()) {
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);

    // Constrain the width by the static alignment of the last dimension,
    // unless the gcd already covers the whole dimension.
    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
      max_vector_size = gcd_base;
    }

    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

    // Flatten the indices into a linear element offset so the access pattern
    // can be checked against the loop variable.
    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
    for (int i = indices.size() - 1; i >= 0; --i) {
      elem_offset = elem_offset + indices[i] * stride;
      stride = stride * buffer->shape[i];
    }

    // Halve the vector width until the flattened offset is provably
    // contiguous and aligned for that width.
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
    }
  } else if (vector_size_ <= 4) {
    // Dynamic last dimension: emit a runtime divisibility check instead of
    // proving alignment statically.
    dynamic_ = true;
    PrimExpr offset = buffer.OffsetOf(indices).back();
    condition_ = (truncmod(offset, vector_size_) == 0);
  }
}
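
// Worked example (illustrative): buffer shape [64, 130] (float32), inner loop
// extent 8, planner defaults vector_size_ = 4 and max_vector_size = 4:
//   gcd(4, extent 8)            -> 4
//   modular_set(130)            -> {coeff = 0, base = 130}; gcd(4, 0) = 4
//   gcd_base = gcd(4, 130) = 2  -> 2 < 130, so max_vector_size = 2
//   vector_size_ = gcd(2, 4)    -> 2
// IndiceCanVectorize may then halve vector_size_ further if the flattened
// offset is not provably contiguous in the loop variable.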

class AtomicAddVectorizeRewriter : public StmtExprMutator {
public:
  explicit AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
      : vector_size_(plan.vector_size), dynamic_(plan.dynamic),
        condition_(plan.condition) {}

private:
  /**
   * @brief Visits a For node and rewrites the innermost loop for atomic-add
   * vectorization.
   *
   * If the visited For node is the recorded innermost loop, this method
   * validates that the loop extent is a constant, divisible by the planned
   * vector size, and has a zero minimum. When vectorization is static
   * (dynamic_ == false) it:
   *  - creates a new loop variable reusing the old variable's name hint,
   *  - substitutes the old loop var with `new_var * vector_size_` so each
   *    iteration of the new loop maps to a contiguous vector-sized chunk,
   *  - returns a new For with extent divided by vector_size_ and the
   *    transformed body.
   *
   * If dynamic_ is true, the method returns the (possibly mutated) inner For
   * unchanged.
   *
   * Side effects:
   *  - updates inner_for_ to point to the current For node during visitation.
   *  - performs runtime checks (ICHECK) to enforce: constant extent, extent %
   *    vector_size_ == 0, and zero loop minimum; violations terminate
   *    execution.
   *
   * @return The original or transformed For statement as a Stmt.
   */
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) {
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto new_var = Var(old_var->name_hint);
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Map<Var, PrimExpr> vmap;
        vmap.Set(old_var, new_var * vector_size_);
        Stmt body = Substitute(fnode->body, vmap);
        return For(new_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
    return ret;
  }
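
  // Illustrative before/after of the static rewrite (vector_size_ == 4,
  // loop variable name kept for readability):
  //   before: for (i, 0, 8) { call_extern("AtomicAdd", C[i],   V[i])   }
  //   after:  for (i, 0, 2) { call_extern("AtomicAdd", C[i*4], V[i*4]) }
  // VisitExpr_ below then renames the call to AtomicAddx4 and passes
  // address_of() handles so one extern call covers four contiguous lanes.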

  PrimExpr VisitExpr_(const CallNode *node) final {
    if (dynamic_) {
      return StmtExprMutator::VisitExpr_(node);
    }
    if (vector_size_ == 2 || vector_size_ == 4) {
      // AtomicAdd extern calls carry (name, dst, value), so three args
      // minimum.
      if (node->op == builtin::call_extern() && node->args.size() >= 3) {
        if (const auto *func_name = node->args[0].as<StringImmNode>()) {
          if (func_name->value == "AtomicAdd") {
            const BufferLoadNode *temp_dst_node =
                node->args[1].as<BufferLoadNode>();
            const BufferLoadNode *temp_value_node =
                node->args[2].as<BufferLoadNode>();
            if (!temp_dst_node || !temp_value_node) {
              return StmtExprMutator::VisitExpr_(node);
            }
            const BufferLoad dst_node = GetRef<BufferLoad>(temp_dst_node);
            const BufferLoad value_node = GetRef<BufferLoad>(temp_value_node);

            // Pass pointers to the first lane; the extern AtomicAddx2/x4
            // helpers consume vector_size_ contiguous elements.
            Call address_of_dst =
                Call(DataType::Handle(), builtin::address_of(), {dst_node});
            Call address_of_value =
                Call(DataType::Handle(), builtin::address_of(), {value_node});

            Array<PrimExpr> new_args;
            if (vector_size_ == 2) {
              new_args.push_back(StringImm("AtomicAddx2"));
            } else {
              new_args.push_back(StringImm("AtomicAddx4"));
            }
            new_args.push_back(address_of_dst);
            new_args.push_back(address_of_value);

            Call new_call =
                tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);

            return new_call;
          }
        }
      }
    }
    return StmtExprMutator::VisitExpr_(node);
  }
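
  // Net effect of the rewrite above (sketch):
  //   call_extern("AtomicAdd", C[k], V[k])
  //     -> call_extern("AtomicAddx4", address_of(C[k]), address_of(V[k]))
  // where the extern AtomicAddx4 is expected to add four consecutive
  // elements starting at the given addresses.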

  const ForNode *inner_for_ = nullptr;
  const int vector_size_;
  const bool dynamic_;
  const PrimExpr condition_;
};

For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
  AtomicAddVectorizePlanResult res = {1, false, 0};
  AtomicAddVectorizePlanner planner;
  res = planner.Plan(for_node, compute_capability);
  int vectorize_hint = res.vector_size;
  if (vectorize_hint == 1)
    return for_node;
  auto rewriter = AtomicAddVectorizeRewriter(res);
  return Downcast<For>(rewriter(for_node));
}
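
// Usage sketch (hypothetical caller, assuming a lowered TIR loop `loop` whose
// body contains call_extern("AtomicAdd", ...) and a Hopper target):
//   For vectorized = VectorizeAtomicAdd(loop, /*compute_capability=*/90);
// When the planner cannot prove a vector width > 1, the input loop is
// returned unchanged.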

} // namespace tl
} // namespace tvm