/*!
 * \file tl/op/atomic_add.cc
 *
 * Define the element-wise atomic add operator.
 */

#include "./atomic_add.h"
#include "./region.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Extracts a numeric architecture identifier from a Target's "arch"
 * attribute.
 *
 * Reads the Target's "arch" string (must be defined) and, if it has the form
 * "sm_<N>", parses and returns N as an integer. For any other arch string,
 * returns 0.
 *
 * @param target Target whose "arch" attribute will be inspected (ICHECKs that
 * the attribute is defined).
 * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
    arch_int = atoi(&arch_str[3]);
  } else {
    arch_int = 0;
  }
  return arch_int;
}
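
// Quick reference (comment only): expected results of GetArchInt for a few
// illustrative targets. atoi stops at the first non-digit, so suffixed arch
// strings such as "sm_90a" still yield their leading number.
//
//   GetArchInt(Target("cuda -arch=sm_80"));   // returns 80
//   GetArchInt(Target("cuda -arch=sm_90a"));  // returns 90 ("a" is ignored)
//   GetArchInt(target_without_arch_attr);     // trips the ICHECK above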

/**
 * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
 *
 * Builds the internal AtomicAddNode, extracts the source and destination
 * regions and their backing Buffers from the first two call-style expressions
 * in `args` (via RegionOp), and stores them along with their ranges. If a third
 * argument is provided, it is interpreted as an integer immediate and stored as
 * the node's coalesced width.
 *
 * @param args Call-style PrimExprs where:
 *             - args[0] is the source region call,
 *             - args[1] is the destination region call,
 *             - args[2] (optional) is an IntImm specifying coalesced width.
 * @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
 *
 * Notes:
 * - The constructor checks that args[0] and args[1] are CallNodes.
 * - The constructed node is stored in this->data_.
 */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region->GetRanges();
    bf[i] = region->GetBuffer();
  }
  std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
  std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    node->coalesced_width = Downcast<IntImm>(args[2]);
  }
  data_ = std::move(node);
}
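
// Shape of the arguments this constructor expects, sketched as a comment.
// The spelling below is illustrative rather than an exact frontend form; what
// matters is that args[0] and args[1] are call nodes that RegionOp can unpack,
// and that an optional third IntImm becomes coalesced_width.
//
//   args[0]  ->  region call describing the source access (buffer + ranges)
//   args[1]  ->  region call describing the destination access
//   args[2]  ->  IntImm coalesced width (optional)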

/**
 * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
 *
 * Produces a new AtomicAddNode object copied from this node. If this node has
 * an associated ParallelOp (par_op_), the parallel op is cloned and attached to
 * the new node so the cloned operator preserves parallelization state.
 *
 * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
 */
TileOperator AtomicAddNode::Clone() const {
  auto op = make_object<AtomicAddNode>(*this);
  if (par_op_.defined()) {
    op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
  }
  return AtomicAdd(op);
}

/**
 * @brief Create data-parallel iteration variables for non-singleton dimensions
 * of the source.
 *
 * Constructs an Array of IterVar corresponding to each dimension in `src_range`
 * whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
 * Var named sequentially ("i", "j", "k", ...) with the same dtype as the
 * extent, and type IterVarType::kDataPar. The ordering of returned itervars
 * matches the order of dimensions in `src_range`.
 *
 * @return Array<IterVar> Iteration variables for all non-singleton extents in
 * `src_range`.
 */
Array<IterVar> AtomicAddNode::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}
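
// Worked example (comment only): with src_range extents [1, 64, 128], the
// unit extent is skipped and MakeIterVars() returns two kDataPar itervars,
// named sequentially from "i":
//
//   i : dom = Range(0, 64)
//   j : dom = Range(0, 128)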

// ivs: itervars returned by MakeIterVars()
/**
 * @brief Build index expressions for either source or destination from loop
 * iter vars.
 *
 * Given a list of iteration variables that correspond to the non-singleton
 * extents of the selected region (source when src_dst == 0, destination when
 * src_dst == 1), return an array of index expressions matching the full rank of
 * that region. For dimensions with extent == 1, the corresponding index is the
 * range's minimum; otherwise the index is `min + ivar`.
 *
 * @param ivs Iteration variables in order for all non-singleton dimensions of
 * the chosen region.
 * @param src_dst Selects which region to index: 0 for source (src_range), 1 for
 * destination (dst_range).
 * @return Array<PrimExpr> Index expressions for every dimension of the selected
 * region, in original dimension order.
 *
 * @note The function checks that the number of provided iter vars equals the
 * number of non-singleton extents; it will abort (ICHECK) if they differ.
 */
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                           int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << "src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}
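
// Worked example (comment only), continuing the [1, 64, 128] case above.
// With range minimums m0, m1, m2 and itervars {i, j} from MakeIterVars():
//
//   MakeIndices({i, j}, 0)  ->  [m0, m1 + i, m2 + j]
//
// Unit-extent dimensions contribute only their range minimum, so the result
// always covers the full rank of the selected region.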

/**
 * @brief Build a combined bound-check predicate for indexed access.
 *
 * Constructs an AND'd predicate ensuring each non-singleton index (derived from
 * `ivs`) stays within [0, extent) for the selected operand (source when
 * `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
 * range list this produces two conditions:
 *   - range.min + iv >= 0
 *   - range.min + iv < extent
 *
 * Conditions that the analyzer can prove (with symbolic bounds) are omitted.
 * If no uncertain conditions remain, an empty PrimExpr is returned.
 *
 * Note: the function ICHECKs that `extents.size()` equals the number of ranges
 * for the selected operand.
 *
 * @param ivs Iteration variables corresponding to non-singleton extents (order
 *            matches the non-unit ranges of the chosen operand).
 * @param extents Per-dimension upper bounds to check against; must have the
 *                same size as the selected range list.
 * @param src_dst Selects which ranges to validate: 0 => `src_range`, else
 *                `dst_range`.
 * @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
 *         an empty PrimExpr when no checks are required.
 */
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                      const Array<IterVar> &ivs,
                                      Array<PrimExpr> extents,
                                      int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}
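
// Worked example (comment only): for a destination range Range(m, 64) checked
// against extents = dst->shape = [N] with itervar i, the candidate conditions
// are
//
//   m + i < N    and    m + i >= 0
//
// Any condition the analyzer can prove with symbolic bounds (e.g. when m is 0
// and N >= 64) is dropped; if nothing remains, the undefined PrimExpr returned
// here means the caller emits no guard.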

/**
 * @brief Build a SIMT-style loop nest that performs element-wise atomic
 * additions from src to dst.
 *
 * Constructs a nested loop (parallelized per iter var) that loads a value from
 * the source buffer, optionally casts it to the destination dtype, and performs
 * an extern atomic add into the destination buffer address. For scalar
 * (zero-dimensional) operations a trivial serial For with a single BufferStore
 * is returned.
 *
 * The method:
 * - Creates iter vars for all non-singleton extents and binds them into the
 * provided analyzer.
 * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
 * - Computes indexed accesses and emits optional bound predicates;
 * out-of-bounds accesses are masked to zero when predicates are uncertain.
 * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
 * src_value)` call wrapped in an Evaluate statement.
 * - Wraps the body with a parallel For at each loop level. If `coalesced_width`
 * is defined it is attached as the "coalesced_width" annotation on each loop.
 *
 * Note: This function mutates the analyzer binding state by binding loop
 * variables and may fail via ICHECK if internal assumptions about shapes are
 * violated.
 *
 * @return A nested For loop (parallel loops) implementing the atomic-add
 * kernel. For scalar cases a serial For of extent 1 is returned.
 */
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.size() == 0;
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  ICHECK(loop_vars.size() <= src_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", src_range.size() = " << src_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  ICHECK(loop_vars.size() <= dst_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm("AtomicAdd"));

  PrimExpr src_value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    src_value = Cast(dst->dtype, src_value);
  if (src_predicate.defined())
    src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));

  PrimExpr dst_value = BufferLoad(dst, dst_indices);
  if (dst_predicate.defined())
    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));

  Call address_of_value =
      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});

  new_args.push_back(address_of_value);
  new_args.push_back(src_value);

  Call atomicadd_call =
      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);

  Stmt body = tvm::tir::Evaluate(atomicadd_call);

  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }

    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, std::nullopt, annotations);
  }
  return Downcast<For>(body);
}
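
// Sketch of the loop nest produced above for a 64x128 region (pseudo-TIR in a
// comment; predicates, dtype casts and the "coalesced_width" annotation are
// layered on top only when needed):
//
//   for (i, 0, 64) "parallel" {
//     for (j, 0, 128) "parallel" {
//       call_extern("AtomicAdd", address_of(dst[i, j]), src[i, j])
//     }
//   }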

/**
 * @brief Lower the atomic-add top-level operator into a parallel, vectorized
 * TIR loop.
 *
 * Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
 * layout inference at multiple levels, partitions the root loop by the provided
 * thread variable, vectorizes the thread loop, and returns the final
 * (optionally predicate-guarded) statement.
 *
 * The lowering pipeline:
 *  - Build the SIMT loop via MakeSIMTLoop.
 *  - Fuse parallel loops into a single For and wrap as a ParallelOp.
 *  - Run layout inference at kCommon, kStrict, and kFree levels using fields
 * from `T`.
 *  - Obtain the loop layout, partition the root loop with PartitionLoop by
 * `T.thread_var`.
 *  - Vectorize the partitioned thread loop via VectorizeLoop.
 *  - If the ParallelOp produced a predicate for `T.thread_var`, return an
 * IfThenElse that guards the vectorized loop with that predicate; otherwise
 * return the vectorized loop.
 *
 * @param T Lowering context whose fields are used:
 *   - T.target: target architecture for layout inference and lowering
 * decisions.
 *   - T.thread_var: the Var used to partition the outer loop for thread-level
 * parallelism.
 *   - T.thread_bounds: bounds associated with the thread dimension (used during
 * partitioning).
 *   - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
 * during InferLayout.
 * @param analyzer Analyzer used for symbolic reasoning during partitioning and
 * folding (omitted from detailed param docs as a common analysis utility).
 * @return Stmt A lowered TIR statement representing the parallelized and
 * vectorized atomic-add.
 */
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
  auto par_op = ParallelOp(fused_loop);

  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                    InferLevel::kFree};
  for (auto level : levels) {
    par_op->InferLayout(
        {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
  }
  auto loop_layout = par_op->GetLoopLayout();
  Var thread_var = T.thread_var;
  Range thread_bounds = T.thread_bounds;
  auto thread_loop =
      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
  // TODO(@dyq): buggy implementation, need to fix
  // vectorized_thread_loop = VectorizeAtomicAdd(
  //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
  auto vectorized_thread_loop = VectorizeLoop(thread_loop);

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}
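
// Sketch of the lowering pipeline above for the 64x128 example (comment only;
// the exact post-partition shape depends on the inferred loop layout and the
// thread bounds):
//
//   MakeSIMTLoop        for i, j "parallel": AtomicAdd(...)
//   ParallelLoopFuser   for k in [0, 64*128) "parallel": AtomicAdd(...)
//   PartitionLoop       per-thread loop with k derived from thread_var
//   VectorizeLoop       innermost contiguous chunk rewritten as a vector loop
//
// The IfThenElse guard appears only when ParallelOp reports a predicate for
// thread_var, e.g. when the fused extent does not map evenly onto the thread
// range.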

/**
 * @brief Infer and return the layout map for the atomic add operator.
 *
 * Constructs a cached ParallelOp (by building the SIMT loop) if not already
 * present, validates that local.fragment layouts for src and dst match when
 * both are provided, and then delegates layout inference to the underlying
 * ParallelOp.
 *
 * @param T Layout inference inputs, including an optional mapping of buffers to
 * layouts.
 * @param level Inference strictness level.
 * @return LayoutMap The inferred layout mapping for buffers used by this
 * operator.
 *
 * @note This method mutates the AtomicAddNode by creating and storing a
 * ParallelOp on first invocation.
 * @throws If both src and dst have layouts in `local.fragment` and their
 * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
 */
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
  if (!par_op_.defined()) {
    arith::Analyzer analyzer;
    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
  }
  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
      if (src_layout && dst_layout) {
        ICHECK(src_layout->IsEqual(dst_layout, true))
            << "Get different layout for " << src << " and " << dst
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the layout";
      }
    }
  }
  return par_op_->InferLayout(T, level);
}
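
// Illustrative note (comment only): the check above fires when, for example,
// src and dst are both "local.fragment" buffers but were inferred with
// different thread-to-element mappings; as the message suggests, staging one
// side through shared memory is the usual workaround. In all other cases the
// check is skipped and inference is delegated entirely to the cached
// ParallelOp.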

TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm