/*!
 * \file tl/op/reduce.cc
 *
 * Define the reduce and cumulative-sum (cumsum) operators.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a ReduceOp from raw TL arguments and a buffer mapping.
 *
 * Interprets `args` and `vmap` to populate an internal ReduceOpNode:
 * - args[0]: access pointer for the source buffer
 * - args[1]: access pointer for the destination buffer
 * - args[2]: string literal specifying the reduce type: "sum", "abssum",
 *            "absmax", "max", or "min"
 * - args[3]: integer literal for the reduction dimension (axis)
 * - args[4]: boolean literal indicating whether to clear/init the destination
 *
 * The constructor resolves the access pointers via `vmap`, maps the reduce
 * type string to the ReduceType enum, assigns the reduction dimension and
 * clear flag, and stores the constructed node in `data_`. An invalid reduce
 * type triggers a fatal check.
 *
 * @param args Array of TL prim-expr arguments as described above.
 * @param vmap Mapping from variables (from access pointers) to Buffer objects.
 */
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    node->type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    node->type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    node->type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    node->type = ReduceType::kMax;
  else if (reduce_type == "min")
    node->type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  node->clear = args[4].as<Bool>().value();
  data_ = std::move(node);
}
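
// Illustrative example (hypothetical values): a TL-level reduce of the second
// axis with "max" that clears the destination would arrive here as
//   args = { <src access_ptr>, <dst access_ptr>, "max", 1, true }
// and produce node->type = ReduceType::kMax, node->dim = 1,
// node->clear = true.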

/**
 * @brief Create a copy of this ReduceOpNode wrapped as a TileOperator.
 *
 * Returns a new TileOperator holding a freshly allocated ReduceOpNode
 * constructed as a copy of this node.
 *
 * @return TileOperator A tile operator that owns the cloned ReduceOpNode.
 */
TileOperator ReduceOpNode::Clone() const {
  auto op = make_object<ReduceOpNode>(*this);
  return ReduceOp(op);
}

/**
 * @brief Create a deep copy of this CumSum op node wrapped as a TileOperator.
 *
 * Returns a new TileOperator whose underlying CumSumOpNode is a copy of
 * the current node. Useful for cloning operators when building or
 * transforming computation graphs.
 *
 * @return TileOperator A TileOperator containing a copy of this node.
 */
TileOperator CumSumOpNode::Clone() const {
  auto op = make_object<CumSumOpNode>(*this);
  return CumSumOp(op);
}

/**
 * @brief Create the initial accumulator value for the destination buffer based
 * on reduction type.
 *
 * Returns the PrimExpr representing the initial value stored in the destination
 * accumulator before any source elements are combined. The returned value
 * depends on the destination dtype and the node's reduction type:
 * - kSum, kAbsSum: zero of the destination dtype.
 * - kMax: minimum representable value for signed integers, zero for unsigned
 * integers, and -INFINITY for floating point.
 * - kMin: maximum representable value for signed integers, all-ones (max) for
 * unsigned integers, and +INFINITY for floating point.
 * - kAbsMax: zero of the destination dtype.
 *
 * The function will abort (ICHECK failure) if the reduction type is
 * unrecognized.
 *
 * @return PrimExpr initial value appropriate for `dst->dtype` and `type`.
 */
PrimExpr ReduceOpNode::MakeInitValue() const {
  auto dst_dtype = dst->dtype;
  auto is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();

  switch (type) {
  case ReduceType::kSum:
    return make_zero(dst->dtype);
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    if (is_int) {
      // min_value avoids shifting past the width of `int`, which the previous
      // `-(1 << (bits - 1))` form did for 32-bit and wider element types.
      return min_value(dst_dtype);
    } else if (is_uint) {
      return make_const(dst->dtype, 0);
    } else {
      return make_const(dst->dtype, -INFINITY);
    }
  case ReduceType::kMin:
    // Signed max and unsigned all-ones are both the dtype's maximum value.
    if (is_int || is_uint) {
      return max_value(dst_dtype);
    } else {
      return make_const(dst->dtype, INFINITY);
    }
  case ReduceType::kAbsMax:
    return make_const(dst->dtype, 0);
  default:
    ICHECK(0);
    return PrimExpr();  // unreachable
  }
}
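
// A few concrete identities produced by the rules above: max over int8 starts
// from -128 and max over float16 from -inf; min over uint16 starts from
// 0xFFFF; sum, abssum, and absmax accumulators all start from zero.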

/**
 * @brief Combine two scalar expressions according to this node's reduction
 * type.
 *
 * Casts the right operand to the left operand's dtype if they differ, then
 * returns the reduction of `a` and `b` using the operator specified by `type`:
 * - kSum: `a + b`
 * - kAbsSum: `a + max(b, -b)`
 * - kMax: `max(a, b)`
 * - kMin: `min(a, b)`
 * - kAbsMax: `max(max(a, b), -min(a, b))`
 *
 * @param a Left-hand operand (result dtype drives the output dtype).
 * @param b Right-hand operand (will be cast to `a`'s dtype if needed).
 * @return PrimExpr The combined expression with dtype equal to `a.dtype`.
 *
 * @note The function fails an ICHECK on an unknown/unsupported reduction type.
 */
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  case ReduceType::kAbsMax:
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}
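
// Worked example for kAbsMax with a = -3, b = 2:
//   max(max(-3, 2), -min(-3, 2)) = max(2, 3) = 3 = max(|a|, |b|),
// i.e. the formula computes the absolute maximum without an explicit abs(),
// and it stays correct once the accumulator already holds a value >= 0.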

/**
 * @brief Map the reduction type to the codegen reducer name used by the
 * device-side tl::AllReduce helpers.
 *
 * Returns the string identifier of the code-generation reducer corresponding to
 * this ReduceOpNode's `type`. Mapping:
 * - kSum, kAbsSum -> "tl::SumOp"
 * - kMax, kAbsMax -> "tl::MaxOp"
 * - kMin -> "tl::MinOp"
 *
 * The function terminates with a check failure if `type` is unknown.
 *
 * @return std::string Reducer name used by codegen extern calls.
 */
std::string ReduceOpNode::MakeCodegenReducer() const {
  switch (type) {
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
  default:
    ICHECK(0);
    return "";
  }
}
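
// These names parameterize the device-side AllReduce helper; Lower() below
// assembles extern-call symbols such as (a representative instance)
//   tl::AllReduce<tl::SumOp, 32, 1, 0>::run
// where the template arguments are <reducer, reducing_threads, scale,
// thread_offset>, with an extra all-threads argument for the Hopper variant.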

/**
 * @brief Lower the Reduce operator node to a TIR statement.
 *
 * Lowers a ReduceOpNode that targets fragment-local buffers into a sequence of
 * TIR statements implementing: per-thread local reduction, inter-thread
 * AllReduce (when needed), and final writeback (with an optional duplicate
 * clear buffer to avoid in-place conflicts). Supports reduction kinds
 * (sum/abs-sum/max/min/abs-max) and handles layout-driven index mapping and
 * loop partitioning to thread axes.
 *
 * @param T Lowering context providing buffer remapping, layout map, target and
 *          thread bounds, and workspace allocation helper. Must contain
 *          fragment-local mappings for both src and dst.
 * @param analyzer Symbolic analyzer used to simplify and compress iterators.
 * @return Stmt The constructed TIR statement implementing the reduction.
 *
 * Preconditions:
 * - src and dst buffers must be in "local.fragment" scope.
 * - The layouts must have compatible input/output dimensions for the
 *   specified reduction axis.
 *
 * Failure modes:
 * - ICHECK failures reject unsupported scopes, dimension mismatches, unknown
 *   reduction types, and other violated invariants; each aborts lowering with
 *   a fatal check failure.
 */
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  size_t src_dim = src_layout->InputDim();
  size_t dst_dim = dst_layout->InputDim();

  bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;

  if (is_1d_reduce) {
    ICHECK(is_one(dst_layout->OutputShape().back()))
        << "Reduce for scalar not implemented.";
  } else {
    ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
  }

  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_dim; i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars;
  if (!is_1d_reduce) {
    src_vars = dst_vars;
  }
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // Sum-style reductions accumulate into their destination, so the
  // accumulator must always start from the identity value; other reductions
  // only need an init when the caller requested a cleared destination.
  bool is_sum_like =
      this->type == ReduceType::kSum || this->type == ReduceType::kAbsSum;
  bool require_init = this->clear || is_sum_like;

  // For sum-style reductions without `clear`, reduce into a separate scratch
  // buffer first and add it back afterwards, so the forced init above cannot
  // wipe out the destination's existing contents.
  Buffer clear_buffer = dst_buffer;
  bool need_duplicate = is_sum_like && !this->clear;
  if (need_duplicate) {
    // Create a new buffer with the same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

  // make reduce-init stmt
  if (require_init)
    stmts.push_back(
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, std::nullopt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;

      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;

      auto thread_offset = T.thread_bounds->min;
      if (TargetIsHopper(T.target)) {
        auto all_threads = T.thread_bounds->extent;
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ", " << all_threads << ">::run_hopper";
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ">::run";
      }
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(
            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread = BufferStore(
      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

  // copy clear_buffer back to dst_buffer
  if (need_duplicate) {
    // need_duplicate only holds for sum-style reductions (kSum/kAbsSum), so
    // accumulate the scratch result onto the destination's prior contents.
    ICHECK(is_sum_like) << "Unsupported reduce type: " << (int)this->type;
    stmts.push_back(BufferStore(dst_buffer,
                                Add(BufferLoad(dst_buffer, dst_indices),
                                    BufferLoad(clear_buffer, dst_indices)),
                                dst_indices));
  }
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
  return body;
}
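
// Overall shape of the lowered code, sketched as TIR-like pseudocode for a
// 2-D sum reduce over dim = 1 (extents and index maps depend on the fragment
// layouts; the AllReduce call appears only when threads share a reduced row):
//
//   for i in parallel(...):              // partitioned onto thread_var
//     dst_clear[i] = 0                   // MakeInitValue, when required
//     for rv in unrolled(...):           // thread-local partial reduction
//       dst_clear[i] += src[...]
//     dst_clear[i] = tl::AllReduce<tl::SumOp, R, S, off>::run(dst_clear[i], ws)
//     dst[i] += dst_clear[i]             // copy-back, only if need_duplicate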

/**
 * @brief Infer a layout mapping for the destination buffer of a Reduce
 * operator.
 *
 * When inference level is below `kStrict`, and both source and destination
 * buffers live in `local.fragment` with a known source fragment layout, this
 * computes a candidate destination Fragment layout that accounts for
 * replication over the reduction dimension and binds thread ranges from
 * `T.thread_bounds`.
 *
 * Behavior:
 * - Constructs a destination Fragment whose replicate extent equals
 *   src.shape[dim] * src_fragment.ReplicateExtent(), and whose threading is
 *   derived from the source fragment with the reduction dimension folded out.
 * - If no layout exists for `dst` in `T.layout_map`, returns a map
 *   {dst -> inferred}.
 * - If `dst` already has a layout, validates that the existing layout
 *   strictly contains the computed layout (shapes match and fragment
 *   containment holds). If compatible but the computed replicate extent is
 *   larger, returns the new layout.
 * - In all other cases (strict inference level, unsupported scopes, or no
 *   src layout), returns an empty map.
 *
 * @param T Layout inference context containing `layout_map` and
 *   `thread_bounds`.
 * @param level Inference strictness; no inference is performed at or above
 *   `kStrict`.
 * @return LayoutMap A mapping for `dst` to an inferred Fragment layout, or an
 *   empty map.
 * @throws LayoutConflictException if an existing `dst` layout conflicts with
 *   the computed layout (not containable or incompatible replication extents).
 */
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);

    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check if the computed layout is compatible with the existing one: the
      // existing layout must strictly contain the computed layout.
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (int i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
  }
  return {};
}
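
// Illustrative sizing: reducing dim = 1 of a [64, 128] source fragment whose
// own ReplicateExtent() is 1 yields dest_buffer_rep_extent = 128 * 1 = 128,
// i.e. every thread that held part of a source row maps to a replica of the
// corresponding destination element, keeping the reduced value thread-local.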

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

/**
 * @brief Construct a CumSumOp from a list of arguments and a buffer map.
 *
 * Expects args to contain exactly four PrimExprs in this order:
 *  0: access pointer to source buffer (src),
 *  1: access pointer to destination buffer (dst),
 *  2: integer dimension to perform the cumulative sum along (dim),
 *  3: boolean flag indicating whether to compute the cumsum in reverse
 * (reverse).
 *
 * The constructor resolves src and dst from the provided BufferMap and stores
 * the parsed dim and reverse values on the node. It verifies that args.size()
 * == 4 and that dim is a valid axis for the source buffer shape.
 *
 * @param args Array of PrimExpr as described above.
 */
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    CumSum arguments:
      src: input buffer
      dst: output buffer
      dim: dimension to cumsum
      reverse: whether to cumsum in reverse order
   */
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
  data_ = std::move(node);
}
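
// Illustrative example (hypothetical values): a forward cumsum along the last
// axis of a 2-D buffer arrives here as
//   args = { <src access_ptr>, <dst access_ptr>, 1, false }
// i.e. node->dim = 1 and node->reverse = false.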

/**
 * @brief Lower the CumSum operator to TIR.
 *
 * Produces a TIR statement implementing cumulative sum depending on buffer
 * scopes:
 * - For shared/shared.dyn scopes: emits an extern call to
 * `tl::CumSum2D<threads, dim, reverse>::run` with arguments [function_name,
 * src.access_ptr(1), dst.access_ptr(3), src.shape...]. The number of threads is
 * taken from `T.thread_bounds->extent`. Returns an Evaluate(Call(...))
 * statement.
 * - For local.fragment scopes on both src and dst: fatal error (not
 * implemented).
 * - For any other scope combinations: fails with an assertion.
 *
 * The `analyzer` parameter is accepted for interface compatibility but is not
 * used by this lowering.
 *
 * @param T Lowering arguments (provides thread bounds and other lowering
 * context).
 * @return Stmt A TIR statement representing the lowered cumulative-sum
 * operation.
 */
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
    auto threads = T.thread_bounds->extent;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    for (size_t i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}
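
// Representative emitted call (values are illustrative): for a [64, 128]
// shared buffer processed by 128 threads with dim = 1 and reverse = false,
// the extern call above becomes
//   tl::CumSum2D<128, 1, false>::run(src_ptr, dst_ptr, 64, 128)
// where src_ptr/dst_ptr are the read and read-write access pointers.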

/**
 * @brief Layout inference for CumSum operator.
 *
 * CumSum does not perform any layout inference; this function always returns
 * an empty mapping. The operator's lowering expects shared-memory semantics
 * and layout decisions are handled elsewhere.
 *
 * @param T Layout inference inputs (buffers, existing layouts, etc.).
 * @param level Inference strictness level (unused).
 * @return LayoutMap Empty map indicating no inferred layouts.
 */
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm