/*!
 * \file tl/op/reduce.cc
 * \brief Implementation of reduction operators (reduce and cumsum).
 */

#include "reduce.h"

#include <cmath>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h"

namespace tvm {
namespace tl {

using namespace tir;

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
                                            const BufferMap &vmap) {
  // Case 1: Already a BufferRegion
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }

  // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
  // extent=1)
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> ranges;
    for (const PrimExpr &index : load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
        ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
            << "Only stride-1 Ramp is supported in region conversion";
        ICHECK(ramp->lanes.as<IntImmNode>())
            << "Scalable vector lanes not supported in region conversion";
        ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        ranges.push_back(Range::FromMinExtent(index, 1));
      }
    }
    return BufferRegion(load->buffer, ranges);
  }

  // Case 3: Call nodes (only tl.region)
  if (const auto *call = arg.as<CallNode>()) {
    // tl.region(...) — reconstruct via RegionOp
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args, vmap);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
61
62
63
64
65
66
67
68
69
70
    // builtin.tvm_access_ptr(...) — map var to Buffer and take full region
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      Var var = Downcast<Var>(call->args[1]);
      Buffer buf = vmap[var];
      Array<Range> ranges;
      for (PrimExpr extent : buf->shape) {
        ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
      }
      return BufferRegion(buf, ranges);
    }
71
72
73
74
75
76
  }

  LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
  throw; // Unreachable
}

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                        int rw_mask) {
  Buffer buf = region->buffer;
  int ndim = static_cast<int>(buf->shape.size());
  ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims";

  PrimExpr offset, extent;
  if (ndim == 1) {
    // Simple 1D region: offset and extent come from the single axis.
    auto axis = region->region[0];
    offset = axis->min;
    extent = axis->extent;
  } else {
    // Compute row-major strides for ndim >= 2
    std::vector<PrimExpr> strides(ndim);
    PrimExpr one = make_const(buf->shape[0].dtype(), 1);
    PrimExpr cur = one;
    for (int i = ndim - 1; i >= 0; --i) {
      strides[i] = cur;
      cur = cur * buf->shape[i];
    }
    // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
    offset = make_const(buf->shape[0].dtype(), 0);
    for (int i = 0; i < ndim - 2; ++i) {
      offset = offset + region->region[i]->min * strides[i];
    }

    // Extent: last two extents product (elements)
    extent =
        region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
  }

  // ptype and return handle
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}

120
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  /// Reduce constructor arguments:
  /// - src: source region (BufferRegion / BufferLoad / tl.region /
  ///   tvm_access_ptr)
  /// - dst: destination region (same accepted forms as src)
  /// - reduce_type: StringImm naming the reduction (parsed by ReduceType)
  /// - dim: IntImm, the src dimension being reduced
  /// - clear: Bool, whether dst is reinitialized before reducing
  CHECK_EQ(args.size(), 5);
  ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();

  // Accept BufferRegion/BufferLoad/tl.region for src/dst
  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->src = node->srcRegion_->buffer;
  node->dst = node->dstRegion_->buffer;

  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
  node->type = ReduceType(reduce_type);
  node->clear = args[4].as<Bool>().value();

  // `dim` later indexes src's shape and layout iterators without bounds
  // checks, so reject out-of-range values up front with a clear message.
  const int src_ndim = static_cast<int>(node->src->shape.size());
  CHECK(node->dim >= 0 && node->dim < src_ndim)
      << "The dim of reduce should be in [0, " << src_ndim << "). Got dim="
      << node->dim << ".";

  data_ = std::move(node);
}

TileOperator ReduceOpNode::Clone() const {
  // Copy-construct a fresh node from *this and wrap it in a new handle.
  ObjectPtr<ReduceOpNode> duplicate = tvm::ffi::make_object<ReduceOpNode>(*this);
  return ReduceOp(duplicate);
}

TileOperator CumSumOpNode::Clone() const {
  // Copy-construct a fresh node from *this and wrap it in a new handle.
  ObjectPtr<CumSumOpNode> duplicate = tvm::ffi::make_object<CumSumOpNode>(*this);
  return CumSumOp(duplicate);
}

// Return the identity (initial accumulator) value of this reduction in dst's
// dtype:
//   sum / abssum / bitor / bitxor : 0
//   max    : lowest signed-int value, 0 for unsigned ints, -inf otherwise
//   min    : highest signed/unsigned-int value, +inf otherwise
//   absmax : 0
//   bitand : all bits set (-1 for signed ints, max value for unsigned ints)
//
// Integer bounds come from tvm::min_value / tvm::max_value. The previous
// `1 << (bits - 1)` / `1 << bits` arithmetic was evaluated in `int` and was
// undefined behavior (signed overflow or over-wide shift) for 32- and 64-bit
// integer types.
PrimExpr ReduceOpNode::MakeInitValue() const {
  const DataType dst_dtype = dst->dtype;
  const bool is_int = dst_dtype.is_int();
  const bool is_uint = dst_dtype.is_uint();

  if (type->isSum()) {
    return make_zero(dst_dtype);
  } else if (type->isAbsSum()) {
    return make_zero(dst_dtype);
  } else if (type->isMax()) {
    if (is_int) {
      return tvm::min_value(dst_dtype);
    } else if (is_uint) {
      return make_const(dst_dtype, 0);
    } else {
      return make_const(dst_dtype, -INFINITY);
    }
  } else if (type->isMin()) {
    if (is_int || is_uint) {
      return tvm::max_value(dst_dtype);
    } else {
      return make_const(dst_dtype, INFINITY);
    }
  } else if (type->isAbsMax()) {
    return make_const(dst_dtype, 0);
  } else if (type->isBitAnd()) {
    if (is_int) {
      return make_const(dst_dtype, -1);
    } else if (is_uint) {
      return tvm::max_value(dst_dtype);
    } else {
      // Should not arrive here
      return make_const(dst_dtype, -INFINITY);
    }
  } else if (type->isBitOr()) {
    return make_zero(dst_dtype);
  } else if (type->isBitXor()) {
    return make_zero(dst_dtype);
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
    return PrimExpr();
  }
}

// Combine the running accumulator `lhs` with a new element `b` according to
// the reduction type. `b` is first cast to lhs's dtype when the dtypes differ.
//   sum: lhs + b          abssum: lhs + max(b, -b)
//   max/min: max/min      absmax: max(|lhs|, |b|)
//   bitand/bitor/bitxor: the corresponding bitwise operation
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
                                  const PrimExpr &b) const {
  PrimExpr rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  if (type->isSum()) {
    return lhs + rhs;
  } else if (type->isAbsSum()) {
    return lhs + Max(rhs, -rhs);
  } else if (type->isMax()) {
    return Max(lhs, rhs);
  } else if (type->isMin()) {
    return Min(lhs, rhs);
  } else if (type->isAbsMax()) {
    return Max(tvm::abs(lhs), tvm::abs(rhs));
  } else if (type->isBitAnd()) {
    return lhs & rhs;
  } else if (type->isBitOr()) {
    return lhs | rhs;
  } else if (type->isBitXor()) {
    return lhs ^ rhs;
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
    // Keep every path returning a value, like MakeInitValue and
    // MakeCodegenReducer do.
    return PrimExpr();
  }
}

// Name of the device-side combiner functor handed to tl::AllReduce.
// The absolute-value variants reuse the plain Sum / Max combiners.
std::string ReduceOpNode::MakeCodegenReducer() const {
  if (type->isSum() || type->isAbsSum()) {
    return "tl::SumOp";
  }
  if (type->isMax() || type->isAbsMax()) {
    return "tl::MaxOp";
  }
  if (type->isMin()) {
    return "tl::MinOp";
  }
  if (type->isBitAnd()) {
    return "tl::BitAndOp";
  }
  if (type->isBitOr()) {
    return "tl::BitOrOp";
  }
  if (type->isBitXor()) {
    return "tl::BitXorOp";
  }
  LOG(FATAL) << "Unsupported reduce type: " << type->type;
  return "";
}

/**
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
 * @brief Lower the Reduce operator to a TIR statement.
 *
 * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of
 * TIR statements implementing: optional initialization, thread-local reduction
 * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call
 * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true),
 * and an optional accumulation or copy back to the destination buffer when a
 * temporary clear buffer is used.
 *
 * Behavior notes:
 * - Only supports src and dst in "local.fragment" scope; otherwise it checks
 *   and aborts with "Reduce for shared memory not implemented.".
 * - Supports both 1D reductions (scalar output) and reductions along a single
 *   extra dimension; validates layout dimensionality consistency.
 * - If `clear` is set (or for sum/abssum reductions), an initial value is
 *   written to the clear buffer; for non-clearing sum/abssum a duplicate
 *   temporary buffer is allocated and accumulated back into dst after
 * reduction.
 * - Performs iterator compression for local reduction loops using `analyzer`.
 * - Detects parallel thread splitting from the normalized iterator sum and
 *   emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
 *   via `builtin::call_extern`. For sufficiently large reducing thread counts
 *   (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
 *   AllReduce call.
 * - The final body is wrapped in parallel loops over the destination spatial
 *   dimensions and partitioned by the lowering thread variable. If a temporary
 *   clear buffer is used, it is allocated for the body.
 *
 * @param T Lowering context providing buffer and layout maps, thread bounds,
 *          target information, thread variable, and workspace allocation
 * helper.
 * @param analyzer Analyzer used for iterator compression and arithmetic
 * normalization.
 * @return Stmt Lowered TIR statement implementing the reduction.
276
 */
277
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
278
279
280
281
282
  auto get_buffer = [&](const Buffer &buf) {
    if (T.buffer_remap.count(buf))
      return T.buffer_remap[buf];
    return buf;
  };
283

284
285
  auto src_scope = this->src.scope();
  auto dst_scope = this->dst.scope();
286

287
  if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
288

289
290
291
292
293
294
    Buffer src_buffer = get_buffer(this->src);
    Buffer dst_buffer = get_buffer(this->dst);
    Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
    Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
    size_t src_dim = src_layout->InputDim();
    size_t dst_dim = dst_layout->InputDim();
295

296
    bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
297

298
299
300
301
302
303
    if (is_1d_reduce) {
      ICHECK(is_one(dst_layout->OutputShape().back()))
          << "Reduce for scalar not implemented.";
    } else {
      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
    }
304

305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    Array<IterVar> dst_vars;
    for (size_t i = 0; i < dst_dim; ++i) {
      Var var = Var(std::string{char('i' + i)});
      dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                                 IterVarType::kDataPar));
    }

    Array<IterVar> src_vars;
    if (!is_1d_reduce) {
      src_vars = dst_vars;
    }
    Range reduce_dom(0, src_layout->InputShape()[this->dim]);
    IterVar reduce_iv(reduce_dom, Var("rv"), IterVarType::kDataPar);
    src_vars.insert(src_vars.begin() + this->dim, reduce_iv);

    Array<PrimExpr> src_indices = src_layout->Forward(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
    Array<PrimExpr> dst_indices = dst_layout->Forward(
        dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

    Array<Stmt> stmts;

    bool require_init = this->clear;
    if (this->type->isSum() || this->type->isAbsSum() ||
        this->type->isBitAnd() || this->type->isBitOr() ||
        this->type->isBitXor()) {
      require_init = true;
    }

    Buffer clear_buffer = dst_buffer;
    bool need_duplicate = false;
    if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
      need_duplicate = true;
    } else if (this->type->isBitAnd() && !this->clear) {
      need_duplicate = true;
    } else if ((this->type->isBitOr() || this->type->isBitXor()) &&
               !this->clear) {
      need_duplicate = true;
    }

    if (need_duplicate) {
      // Create a new buffer with same shape and dtype as dst_buffer
      clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                                 dst_buffer->name + "_clear",
                                 GetPtrStorageScope(dst_buffer->data));
    }
    // make reduce-init stmt
    if (require_init) {
      stmts.push_back(
          BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
    }

    // make thread-local reduce
    Array<PrimExpr> src_indice_compressed;
    Array<IterVar> src_var_compressed;
    for (size_t i = 0; i < src_layout->OutputDim(); ++i) {
      PrimExpr expr;
      IterVar var;
      std::tie(expr, var) = CompressIterator(
          src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
      src_indice_compressed.push_back(expr);
      src_var_compressed.push_back(var);
    }

    Stmt reduce_local = BufferStore(
        clear_buffer,
        this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                         BufferLoad(src_buffer, src_indice_compressed)),
        dst_indices);

    for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
      reduce_local =
          For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
              ForKind::kUnrolled, reduce_local, std::nullopt,
              {{tir::attr::pragma_unroll_explicit, Bool(false)}});
    }
    stmts.push_back(reduce_local);

    PrimExpr src_thread = src_layout->ForwardThread(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
    auto iter_sum =
        arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
    for (const auto &iter_split : iter_sum->args) {
      auto mark = iter_split->source->source.as<Var>();
      ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
      if (mark.value().same_as(src_vars[this->dim]->var)) {
        auto scale = as_const_int(iter_split->scale);
        auto extent = as_const_int(iter_split->extent);
        ICHECK(scale != nullptr && extent != nullptr);
        if (*extent == 1)
          continue;

        int reducing_threads = (*extent) * (*scale);
        std::stringstream ss;

        auto thread_offset = T.thread_bounds->min;
        if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
          auto all_threads = T.thread_bounds->extent;
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ", " << all_threads << ">::run_hopper";
        } else {
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ">::run";
        }
        Array<PrimExpr> thread_reduce_args = {
            StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
        if (reducing_threads >= 32) {
          PrimExpr workspace = T.AddWorkspace(
              *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
          thread_reduce_args.push_back(workspace);
        }
        auto call = Call(clear_buffer->dtype, builtin::call_extern(),
                         thread_reduce_args);
        stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
421
      }
422
    }
423
424
425
426
427
428
429
430
431
432
433
434
435

    if (need_duplicate) {
      PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
      PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
      PrimExpr update;
      if (this->type->isSum() || this->type->isAbsSum()) {
        update = dst_val + src_val;
      } else if (this->type->isBitAnd()) {
        update = this->clear ? src_val : bitwise_and(dst_val, src_val);
      } else if (this->type->isBitOr()) {
        update = bitwise_or(dst_val, src_val);
      } else if (this->type->isBitXor()) {
        update = bitwise_xor(dst_val, src_val);
436
      } else {
437
        LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
438
      }
439
440
441
442
443
444
445
446
447
448
449
450
      stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
    }

    Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
    for (int i = static_cast<int>(dst_layout->InputDim()) - 1; i >= 0; --i) {
      body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
                 ForKind::kParallel, body);
    }

    if (dst_layout->InputDim() > 0) {
      body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
                           dst_layout);
451
    } else {
452
453
      PrimExpr guard = (T.thread_var == T.thread_bounds->min);
      body = IfThenElse(guard, body);
454
    }
455
456
457
458
459
460

    if (need_duplicate) {
      body = Allocate(clear_buffer->data, clear_buffer->dtype,
                      clear_buffer->shape, const_true(), body);
    }
    return body;
461
462
  }

463
464
465
  LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
             << dst_scope << ") is not implemented.";
  return Stmt();
466
467
}

468
469
// Infer a fragment layout for dst from src's fragment layout.
//
// Only runs below InferLevel::kStrict, for fragment -> fragment reductions
// whose src layout is already known. The reduced dimension `dim` is folded
// into the replication axis: dst's replicate extent is
// src->shape[dim] * src_layout->ReplicateExtent() (then condensed).
//
// If dst already has a layout, it must contain the computed one; otherwise a
// LayoutConflictException is thrown. The computed layout is returned only if
// dst has no layout yet, or if its replicate extent exceeds src's.
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (level >= InferLevel::kStrict)
    return {};

  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    // Map dst coordinates back to src coordinates: the reduced dimension is
    // driven by the replication placeholder; the others shift past it.
    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));

    // The thread count and the replicate extent must divide one another.
    // Otherwise, we cannot infer a valid fragment<->fragment layout.
    {
      arith::Analyzer analyzer;
      PrimExpr num_threads = T.thread_bounds->extent;
      // Though the dest_buffer_rep_extent will be compressed at
      // CondenseReplicateVar, we need to check the divisibility here to avoid
      // the issue that the thread count is not divisible by the replicate
      // extent.
      if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) ==
                             0) &&
          !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) ==
                             0)) {
        LOG(FATAL) << "ReduceOp fragment layout inference failed: "
                      "num_threads % replicate_extent != 0. "
                   << "This mapping requires the block's thread count to be "
                      "divisible by the replicate extent (or vice versa). "
                   << "Try one of: (1) choose a thread block size divisible "
                      "by replicate_extent; "
                   << "(2) pick a different reduce dimension or adjust the "
                      "source fragment layout; "
                   << "Details: num_threads=" << num_threads
                   << ", replicate_extent=" << dest_buffer_rep_extent
                   << " (reduced extent=" << indice_rep_extent
                   << " x src replicate extent=" << src_rep_extent << ")"
                   << ", src=" << src << ", dst=" << dst;
      }
    }

    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);

    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};

    // Check if computed layout is compatible with existing: the existing one
    // must strictly contain the computed layout
    auto orig_dst_layout =
        T.layout_map.Get(dst).value().as<Fragment>().value();
    ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
    Array<PrimExpr> indices;
    indices.reserve(dst_layout->InputDim());
    arith::Analyzer inner_analyzer;
    for (size_t i = 0; i < dst_layout->InputDim(); ++i) {
      auto x = InputPlaceholder(i);
      indices.push_back(x);
      // should be literal - literal = 0, any analyzer will work
      ICHECK(is_zero(inner_analyzer.Simplify(
          dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
      inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
    }

    ICHECK(as_const_int(dst_layout->ReplicateExtent()));
    ICHECK(as_const_int(src_layout->ReplicateExtent()));
    auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
    auto src_rep = *as_const_int(src_layout->ReplicateExtent());
    if (dst_rep < src_rep ||
        !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                               inner_analyzer)) {
      std::ostringstream oss;
      oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
          << src << "\nLHS = " << src_layout->DebugOutput()
          << "\nRHS = " << orig_dst_layout->DebugOutput()
          << "\nYou may need to use a shared memory to transform the "
             "layout";
      throw LayoutConflictException(oss.str());
    }

    if (dst_rep > src_rep) {
      return {{dst, dst_layout}};
    }
  }
  return {};
}

// Register tl.reduce. The call carries five arguments, matching what the
// ReduceOp constructor consumes: src, dst, reduce_type, dim, clear.
// The call effect kind is kOpaque.
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Build a BufferRegion that covers all of `buf`: one [0, extent) range per
// dimension, with each zero lower bound in that extent's own dtype.
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
  Array<Range> full_ranges;
  for (size_t axis = 0; axis < buf->shape.size(); ++axis) {
    const PrimExpr &dim_extent = buf->shape[axis];
    full_ranges.push_back(Range(IntImm(dim_extent->dtype, 0), dim_extent));
  }
  return BufferRegion(buf, full_ranges);
}

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /// CumSum constructor arguments:
  /// - src: input region (BufferRegion / BufferLoad / tl.region /
  ///   tvm_access_ptr)
  /// - dst: output region (same accepted forms as src)
  /// - dim: dimension to cumsum
  /// - reverse: whether to cumsum in reverse order
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->src = node->srcRegion_->buffer;
  node->dst = node->dstRegion_->buffer;
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();

  // Reject negative dims too; only the upper bound was checked before.
  CHECK_GE(node->dim, 0)
      << "The dim of cumsum should be non-negative. Got dim=" << node->dim
      << ".";
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
      << "The dim of cumsum should be less than the number of dimensions. Got "
         "dim="
      << node->dim << ", but src has " << node->src->shape.size() << " dims.";

  data_ = std::move(node);
}

// Lower CumSum to an extern call into tl::CumSum1D / tl::CumSum2D.
//
// Only shared ("shared" / "shared.dyn") src and dst are supported.
// fragment -> fragment aborts as not implemented; any other src scope aborts.
// The thread-block extent, (for 2-D) `dim`, and `reverse` become template
// arguments. The call receives read/write access pointers built from the
// src/dst regions, followed by src's shape.
//
// NOTE(review): the 1-D length passed is src->shape[0] even though the
// pointer starts at the region min; presumably regions always cover the whole
// buffer here — TODO confirm.
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  const bool both_fragment = this->src.scope() == "local.fragment" &&
                             this->dst.scope() == "local.fragment";
  if (both_fragment) {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
    return Stmt();
  }

  const bool src_in_shared =
      this->src.scope() == "shared.dyn" || this->src.scope() == "shared";
  if (!src_in_shared) {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
    return Stmt();
  }

  ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");

  std::stringstream kernel;
  const PrimExpr num_threads = T.thread_bounds->extent;
  Array<PrimExpr> call_args;
  const int rank = static_cast<int>(src->shape.size());

  // Build access pointers from regions locally (1 = read, 2 = write).
  PrimExpr src_ptr = MakeAccessPtrFromRegion(srcRegion_, 1);
  PrimExpr dst_ptr = MakeAccessPtrFromRegion(dstRegion_, 2);
  const char *reverse_flag = reverse ? "true" : "false";

  switch (rank) {
  case 1:
    ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
                         "= 0.";
    kernel << "tl::CumSum1D<" << num_threads << ", " << reverse_flag
           << ">::run";
    call_args = {StringImm(kernel.str()), src_ptr, dst_ptr, src->shape[0]};
    break;
  case 2:
    kernel << "tl::CumSum2D<" << num_threads << ", " << dim << ", "
           << reverse_flag << ">::run";
    call_args = {StringImm(kernel.str()), src_ptr, dst_ptr, src->shape[0],
                 src->shape[1]};
    break;
  default:
    LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
               << rank << "D.";
    break;
  }
  return Evaluate(Call(dst->dtype, builtin::call_extern(), call_args));
}

// CumSum does not propose any layouts at any inference level: the result is
// always an empty map.
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  LayoutMap no_layouts;
  return no_layouts;
}

// Register tl.cumsum: four call arguments (src, dst, dim, reverse), matching
// the CumSumOp constructor, with the kOpaque call effect kind.
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>(
        "TCallEffectKind", Integer(static_cast<int>(CallEffectKind::kOpaque)));

// Register reflection metadata for the node types used by this file.
TVM_FFI_STATIC_INIT_BLOCK() {
  // Reduce operator node.
  ReduceOpNode::RegisterReflection();
  // Cumulative-sum operator node.
  CumSumOpNode::RegisterReflection();
  // Reduction-kind node stored in ReduceOpNode::type.
  ReduceTypeNode::RegisterReflection();
}

} // namespace tl
668
} // namespace tvm