/*!
 * \file tl/op/atomic_add.cc
 *
 * \brief Define the element-wise atomic add operator.
 */

#include "./atomic_add.h"
#include "./region.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Extracts a numeric architecture identifier from a Target's "arch"
 * attribute.
 *
 * Reads the Target's "arch" string (must be defined) and, if it has the form
 * "sm_<N>", parses and returns N as an integer. For any other arch string,
 * returns 0.
 *
 * @param target Target whose "arch" attribute will be inspected (ICHECKs that
 * the attribute is defined).
 * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
 */
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
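
// Illustrative behaviour (hypothetical arch strings, not taken from the
// source):
//   "sm_90" -> 90      "sm_80" -> 80      any non "sm_" arch -> 0
// std::stoi stops at the first non-digit, so an arch such as "sm_90a" also
// parses to 90.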

/**
 * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
 *
 * Builds the internal AtomicAddNode and extracts the source and destination
 * regions and their backing Buffers from the first two call-style expressions
 * in `args` (via RegionOp), storing them along with their ranges. An optional
 * third argument enables the TMA lowering path, and an optional fourth
 * argument sets the coalesced width.
 *
 * @param args Call-style PrimExprs where:
 *             - args[0] is the source region call,
 *             - args[1] is the destination region call,
 *             - args[2] (optional) is an IntImm use_tma flag,
 *             - args[3] (optional) is an IntImm specifying coalesced width.
 * @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
 *
 * Notes:
 * - The constructor checks that args[0] and args[1] are CallNodes.
 * - The constructed node is stored in this->data_.
 */
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region->GetRanges();
    bf[i] = region->GetBuffer();
  }
  std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
  std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    node->use_tma = Downcast<IntImm>(args[2]);
  }
  if (args.size() >= 4) {
    node->coalesced_width = Downcast<IntImm>(args[3]);
  }
  data_ = std::move(node);
}
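
// Rough sketch of the expected argument layout (hypothetical shapes, not a
// transcript of real frontend output):
//   args[0]  region call describing the source access, e.g. A[0:64, 0:32]
//   args[1]  region call describing the destination access, e.g. B[0:64, 0:32]
//   args[2]  optional IntImm use_tma flag (0 or 1)
//   args[3]  optional IntImm coalesced width
// `vmap` maps the buffer variables referenced inside the region calls to the
// concrete Buffer objects used above.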

/**
 * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
 *
 * Produces a new AtomicAddNode object copied from this node. If this node has
 * an associated ParallelOp (par_op_), the parallel op is cloned and attached to
 * the new node so the cloned operator preserves parallelization state.
 *
 * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
 */
TileOperator AtomicAddNode::Clone() const {
  auto op = make_object<AtomicAddNode>(*this);
  if (par_op_.defined()) {
    op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
  }
  return AtomicAdd(op);
}

/**
 * @brief Create data-parallel iteration variables for non-singleton dimensions
 * of the source.
 *
 * Constructs an Array of IterVar corresponding to each dimension in `src_range`
 * whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
 * Var named sequentially ("i", "j", "k", ...) with the same dtype as the
 * extent, and type IterVarType::kDataPar. The ordering of returned itervars
 * matches the order of dimensions in `src_range`.
 *
 * @return Array<IterVar> Iteration variables for all non-singleton extents in
 * `src_range`.
 */
Array<IterVar> AtomicAddNode::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}
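
// Illustrative sketch (hypothetical shapes): for
//   src_range = { Range(0, 1), Range(0, 64), Range(0, 32) }
// the unit extent is skipped and the result is roughly
//   { IterVar(Range(0, 64), Var("i"), kDataPar),
//     IterVar(Range(0, 32), Var("j"), kDataPar) }.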

// ivs: itervars returned by MakeIterVars()
/**
 * @brief Build index expressions for either source or destination from loop
 * iter vars.
 *
 * Given a list of iteration variables that correspond to the non-singleton
 * extents of the selected region (source when src_dst == 0, destination when
 * src_dst == 1), return an array of index expressions matching the full rank of
 * that region. For dimensions with extent == 1, the corresponding index is the
 * range's minimum; otherwise the index is `min + ivar`.
 *
 * @param ivs Iteration variables in order for all non-singleton dimensions of
 * the chosen region.
 * @param src_dst Selects which region to index: 0 for source (src_range), 1 for
 * destination (dst_range).
 * @return Array<PrimExpr> Index expressions for every dimension of the selected
 * region, in original dimension order.
 *
 * @note The function checks that the number of provided iter vars equals the
 * number of non-singleton extents; it will abort (ICHECK) if they differ.
 */
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                           int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << ", src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}
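
// Illustrative sketch (hypothetical ranges): with
//   ranges = { Range(4, 1), Range(0, 64), Range(8, 32) }  and  ivs = { i, j },
// the produced indices are { 4, 0 + i, 8 + j }: unit dimensions contribute
// their min, all other dimensions contribute min + ivar.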

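/**
 * @brief Return the base indices and the flattened element count of one
 * operand.
 *
 * For the selected region (source when src_dst == 0, destination otherwise)
 * this collects each dimension's `min` as the access index and multiplies the
 * extents together to obtain the total number of elements covered; the TMA
 * lowering path uses that count to compute the transfer size in bytes.
 *
 * @param src_dst Selects the region: 0 => src_range, otherwise dst_range.
 * @return Pair of (per-dimension base indices, product of extents).
 */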
std::pair<Array<PrimExpr>, PrimExpr>
AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  PrimExpr size = 1;
  for (size_t i = 0; i < ranges.size(); i++) {
    indices.push_back(ranges[i]->min);
    size *= ranges[i]->extent;
  }
  return {indices, size};
}

/**
 * @brief Build a combined bound-check predicate for indexed access.
 *
 * Constructs an AND'd predicate ensuring each non-singleton index (derived from
 * `ivs`) stays within [0, extent) for the selected operand (source when
 * `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
 * range list this produces two conditions:
 *   - range.min + iv >= 0
 *   - range.min + iv < extent
 *
 * Conditions that the analyzer can prove (with symbolic bounds) are omitted.
 * If no uncertain conditions remain, an empty PrimExpr is returned.
 *
 * Note: the function ICHECKs that `extents.size()` equals the number of ranges
 * for the selected operand.
 *
 * @param ivs Iteration variables corresponding to non-singleton extents (order
 *            matches the non-unit ranges of the chosen operand).
 * @param extents Per-dimension upper bounds to check against; must have the
 *                same size as the selected range list.
 * @param src_dst Selects which ranges to validate: 0 => `src_range`, else
 *                `dst_range`.
 * @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
 *         an empty PrimExpr when no checks are required.
 */
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                      const Array<IterVar> &ivs,
                                      Array<PrimExpr> extents,
                                      int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}
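
// Illustrative sketch (hypothetical shapes): with dst_range = { Range(x, 16) }
// checked against extents = { 64 } and loop var `i`, the candidate conditions
// are `x + i < 64` and `x + i >= 0`. Conditions the analyzer can already prove
// are dropped; the remainder are AND-ed, e.g. (x + i < 64) && (x + i >= 0).
// If every condition is provable, an undefined PrimExpr is returned.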

/**
 * @brief Build a SIMT-style loop nest that performs element-wise atomic
 * additions from src to dst.
 *
 * Constructs a nested loop (parallelized per iter var) that loads a value from
 * the source buffer, optionally casts it to the destination dtype, and performs
 * an extern atomic add into the destination buffer address. For scalar
 * (zero-dimensional) operations a trivial serial For with a single BufferStore
 * is returned.
 *
 * The method:
 * - Creates iter vars for all non-singleton extents and binds them into the
 * provided analyzer.
 * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
 * - Computes indexed accesses and emits optional bound predicates;
 * out-of-bounds accesses are masked to zero when predicates are uncertain.
 * - Emits an extern `call_extern("AtomicAdd", dst_value, src_value)` call
 * wrapped in an Evaluate statement.
 * - Wraps the body with a parallel For at each loop level. If `coalesced_width`
 * is defined it is attached as the "coalesced_width" annotation on each loop.
 *
 * Note: This function mutates the analyzer binding state by binding loop
 * variables and may fail via ICHECK if internal assumptions about shapes are
 * violated.
 *
 * @return A nested For loop (parallel loops) implementing the atomic-add
 * kernel. For scalar cases a serial For of extent 1 is returned.
 */
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.empty();
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  ICHECK(loop_vars.size() <= src_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", src_range.size() = " << src_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  ICHECK(loop_vars.size() <= dst_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm("AtomicAdd"));

  PrimExpr src_value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    src_value = Cast(dst->dtype, src_value);
  if (src_predicate.defined())
    src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));

  PrimExpr dst_value = BufferLoad(dst, dst_indices);
  if (dst_predicate.defined())
    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));

  new_args.push_back(dst_value);
  new_args.push_back(src_value);

  Call atomicadd_call =
      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);

  Stmt body = tvm::tir::Evaluate(atomicadd_call);

  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }

    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, std::nullopt, annotations);
  }
  return Downcast<For>(body);
}
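
// Illustrative sketch of the emitted nest (hypothetical 64x32 regions,
// pseudo-TIR rather than literal output):
//   for (i, 0, 64) "parallel" {        // annotated with "coalesced_width"
//     for (j, 0, 32) "parallel" {      // when coalesced_width is defined
//       call_extern("AtomicAdd", dst[i, j], src[i, j])
//     }
//   }
// where src[i, j] is cast to dst->dtype when the dtypes differ, and the
// operands are wrapped in if_then_else(...) masks when the bound predicates
// from MakePredicate are not statically provable.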

/**
 * @brief Lower the atomic-add operator into TIR: either a single TMA bulk
 * reduce or a parallel, vectorized SIMT loop.
 *
 * When `use_tma` is nonzero, the whole source region is added into the
 * destination with one `tma_store` call (`need_reduce = 1`), issued only by
 * the first thread of the block (guarded by
 * `thread_var == thread_bounds->min`). The source and destination element
 * counts are checked to be equal before emitting the call.
 *
 * Otherwise the SIMT lowering pipeline is used:
 *  - Build the SIMT loop via MakeSIMTLoop.
 *  - Fuse parallel loops into a single For and wrap it as a ParallelOp.
 *  - Run layout inference at kCommon, kStrict, and kFree levels using fields
 *    from `T`.
 *  - Obtain the loop layout and partition the root loop with PartitionLoop by
 *    `T.thread_var`.
 *  - Vectorize the partitioned thread loop via VectorizeAtomicAdd, using the
 *    architecture number from GetArchInt(target).
 *  - If the ParallelOp produced a predicate for `T.thread_var`, return an
 *    IfThenElse that guards the vectorized loop with that predicate; otherwise
 *    return the vectorized loop.
 *
 * @param T Lowering context whose fields are used:
 *   - T.target: target architecture for layout inference and vectorization.
 *   - T.thread_var: the Var used to partition the outer loop for thread-level
 *     parallelism (and to guard the TMA path).
 *   - T.thread_bounds: bounds associated with the thread dimension (used
 *     during partitioning and for the TMA guard).
 *   - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
 *     during InferLayout.
 * @param analyzer Analyzer used for symbolic reasoning during partitioning and
 *   size equality checks.
 * @return Stmt The lowered TIR statement.
 */
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  if (use_tma.defined() && use_tma->value != 0) {
    Array<PrimExpr> src_indices, dst_indices;
    PrimExpr src_size, dst_size;
    std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
    std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
    ICHECK(analyzer->CanProveEqual(src_size, dst_size))
        << "src_size = " << src_size << ", dst_size = " << dst_size;
    BufferLoad src_node = BufferLoad(src, src_indices);
    BufferLoad dst_node = BufferLoad(dst, dst_indices);
    Call address_of_src =
        Call(DataType::Handle(), builtin::address_of(), {src_node});
    Call address_of_dst =
        Call(DataType::Handle(), builtin::address_of(), {dst_node});

    int need_reduce = 1;
    int eviction_policy = 0;
    auto body = Evaluate(Call(DataType::Handle(), tma_store(),
                              {address_of_src, address_of_dst,
                               ceildiv(src_size * src->dtype.bits(), 8),
                               need_reduce, eviction_policy}));
    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
  }
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
  auto par_op = ParallelOp(fused_loop);

  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                    InferLevel::kFree};
  for (auto level : levels) {
    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                           false, T.buffer_remap},
                          level);
  }
  auto loop_layout = par_op->GetLoopLayout();
  Var thread_var = T.thread_var;
  Range thread_bounds = T.thread_bounds;
  auto thread_loop =
      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
  auto vectorized_thread_loop = VectorizeAtomicAdd(
      thread_loop, thread_var, thread_bounds, GetArchInt(target));

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}
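
// Illustrative sketch of the two lowering outcomes (pseudo-TIR, hypothetical
// extents; thread_var is typically threadIdx.x):
//
//   // use_tma path: one bulk reduce issued by the first thread of the block
//   if (thread_var == thread_bounds.min) {
//     tma_store(address_of(src[...]), address_of(dst[...]), bytes,
//               /*need_reduce=*/1, /*eviction_policy=*/0)
//   }
//
//   // SIMT path: fused loop partitioned over threads, then vectorized
//   if (predicate(thread_var))        // only when ParallelOp has a predicate
//     for (v, 0, vec) "vectorized" { AtomicAdd(...) }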

/**
 * @brief Infer and return the layout map for the atomic add operator.
 *
 * Constructs a cached ParallelOp (by building the SIMT loop) if not already
 * present, validates that local.fragment layouts for src and dst match when
 * both are provided, and then delegates layout inference to the underlying
 * ParallelOp.
 *
 * @param T Layout inference inputs, including an optional mapping of buffers to
 * layouts.
 * @param level Inference strictness level.
 * @return LayoutMap The inferred layout mapping for buffers used by this
 * operator.
 *
 * @note This method mutates the AtomicAddNode by creating and storing a
 * ParallelOp on first invocation.
 * @throws If both src and dst have layouts in `local.fragment` and their
 * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
 */
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
  if (!par_op_.defined()) {
    arith::Analyzer analyzer;
    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
  }
  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
      if (src_layout && dst_layout) {
        ICHECK(src_layout->IsEqual(dst_layout, true))
            << "Get different layout for " << src << " and " << dst
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the layout";
      }
    }
  }
  return par_op_->InferLayout(T, level);
}
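
// Typical call sequence (a sketch assuming the driver mirrors the level loop
// in Lower(); the argument order follows the LayoutInferArgs usage there):
//   arith::Analyzer analyzer;
//   for (auto level : {InferLevel::kCommon, InferLevel::kStrict,
//                      InferLevel::kFree}) {
//     op->InferLayout({target, thread_bounds, layout_map, &analyzer,
//                      /*skip_thread_partition=*/false, buffer_remap},
//                     level);
//   }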

TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); });

} // namespace tl
} // namespace tvm