/*!
 * \file tl/op/reduce.cc
 * \brief Implementation of the tl.reduce and tl.cumsum tile operators.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
11
#include <tvm/tir/stmt_functor.h>
12
13

#include "../layout/utils.h"
14
#include "../op/parallel.h"
15
#include "../target/utils.h"
16
#include "../transform/loop_partition.h"
17
#include "region.h"
18
#include "tir/transforms/ir_utils.h"
19
#include "tvm/tir/stmt.h"
20
#include "utils.h"
21
22
23
24
25
26

namespace tvm {
namespace tl {

using namespace tir;

// NormalizeToBufferRegion now lives in src/op/utils.{h,cc}.

// MakeAccessPtrFromRegion now lives in src/op/utils.{h,cc}.

31
/**
 * @brief Build a ReduceOp from the arguments of a `tl.reduce` call.
 *
 * Expected call arguments:
 *  - args[0]: src (BufferRegion / BufferLoad / tl.region)
 *  - args[1]: dst (BufferRegion / BufferLoad / tl.region)
 *  - args[2]: reduce type name (StringImm)
 *  - args[3]: dimension of src to reduce over (IntImm)
 *  - args[4]: whether to clear dst before reducing (Bool)
 */
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  // Every argument below is indexed unconditionally, so validate the arity up
  // front instead of failing on an out-of-range access.
  CHECK_EQ(args.size(), 5)
      << "tl.reduce expects 5 arguments (src, dst, reduce_type, dim, clear), "
         "got "
      << args.size() << ".";
  ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
  // Accept BufferRegion/BufferLoad/tl.region for src/dst
  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->src = node->srcRegion_->buffer;
  node->dst = node->dstRegion_->buffer;

  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
  node->type = ReduceType(reduce_type);
  node->clear = args[4].as<Bool>().value();
  data_ = std::move(node);
}

45
/// @brief Copy-construct this node into a new object and wrap it in a fresh
/// ReduceOp handle.
TileOperator ReduceOpNode::Clone() const {
  return ReduceOp(tvm::ffi::make_object<ReduceOpNode>(*this));
}

/// @brief Copy-construct this node into a new object and wrap it in a fresh
/// CumSumOp handle.
TileOperator CumSumOpNode::Clone() const {
  return CumSumOp(tvm::ffi::make_object<CumSumOpNode>(*this));
}

/**
 * @brief Identity value of this reduction, expressed in dst's dtype.
 *
 * Used to seed the accumulator before any source element is combined:
 * 0 for sum/abssum/absmax/bitor/bitxor, the dtype's minimum for max, its
 * maximum for min, and all-ones for bitand.
 *
 * Integer extremes come from tvm::min_value / tvm::max_value, which are exact
 * for every integer width. Computing them as `1 << (bits - 1)` or
 * `1 << bits` on a 32-bit `int` is undefined for 32- and 64-bit dtypes. In
 * practice that yields wrong identities, e.g. 0 as the uint32 min/bitand
 * seed.
 */
PrimExpr ReduceOpNode::MakeInitValue() const {
  auto dst_dtype = dst->dtype;
  bool is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();

  if (type->isSum()) {
    return make_zero(dst_dtype);
  } else if (type->isAbsSum()) {
    return make_zero(dst_dtype);
  } else if (type->isMax()) {
    if (is_int || is_uint) {
      // Smallest representable value (0 for unsigned types).
      return tvm::min_value(dst_dtype);
    }
    return make_const(dst_dtype, -INFINITY);
  } else if (type->isMin()) {
    if (is_int || is_uint) {
      // Largest representable value.
      return tvm::max_value(dst_dtype);
    }
    return make_const(dst_dtype, INFINITY);
  } else if (type->isAbsMax()) {
    // |x| >= 0, so 0 is the identity for the max of absolute values.
    return make_const(dst_dtype, 0);
  } else if (type->isBitAnd()) {
    if (is_int) {
      // -1 has every bit set in two's complement.
      return make_const(dst_dtype, -1);
    } else if (is_uint) {
      // All bits set == the unsigned maximum.
      return tvm::max_value(dst_dtype);
    } else {
      // Should not arrive here
      return make_const(dst_dtype, -INFINITY);
    }
  } else if (type->isBitOr()) {
    return make_zero(dst_dtype);
  } else if (type->isBitXor()) {
    return make_zero(dst_dtype);
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
    return PrimExpr();
  }
}

102
103
104
/**
 * @brief Combine accumulator `lhs` with a new element `b`.
 *
 * `b` is cast to `lhs`'s dtype first, so the combiner always operates on the
 * accumulator dtype.
 *
 * @param lhs Current accumulator value.
 * @param b New element to fold in.
 * @return The combined expression for this reduction type.
 */
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
                                  const PrimExpr &b) const {
  PrimExpr rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  if (type->isSum()) {
    return lhs + rhs;
  } else if (type->isAbsSum()) {
    // Max(x, -x) == |x|
    return lhs + Max(rhs, -rhs);
  } else if (type->isMax()) {
    return Max(lhs, rhs);
  } else if (type->isMin()) {
    return Min(lhs, rhs);
  } else if (type->isAbsMax()) {
    return Max(tvm::abs(lhs), tvm::abs(rhs));
  } else if (type->isBitAnd()) {
    return lhs & rhs;
  } else if (type->isBitOr()) {
    return lhs | rhs;
  } else if (type->isBitXor()) {
    return lhs ^ rhs;
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
    // Unreachable, but keeps every control path returning a value (matches
    // MakeInitValue / MakeCodegenReducer).
    return PrimExpr();
  }
}

129
/**
 * @brief Name of the device-side combiner functor passed to tl::AllReduce.
 *
 * Abs-variants reuse the plain combiner: MakeReduce already applies the
 * absolute value during the thread-local step.
 */
std::string ReduceOpNode::MakeCodegenReducer() const {
  if (type->isSum() || type->isAbsSum())
    return "tl::SumOp";
  if (type->isMax())
    return "tl::MaxOp";
  if (type->isMin())
    return "tl::MinOp";
  if (type->isAbsMax())
    return "tl::MaxOp";
  if (type->isBitAnd())
    return "tl::BitAndOp";
  if (type->isBitOr())
    return "tl::BitOrOp";
  if (type->isBitXor())
    return "tl::BitXorOp";

  LOG(FATAL) << "Unsupported reduce type: " << type->type;
  return "";
}

152
/**
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
 * @brief Lower the Reduce operator to a TIR statement.
 *
 * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of
 * TIR statements implementing: optional initialization, thread-local reduction
 * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call
 * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true),
 * and an optional accumulation or copy back to the destination buffer when a
 * temporary clear buffer is used.
 *
 * Behavior notes:
 * - Only supports src and dst in "local.fragment" scope; otherwise it checks
 *   and aborts with "Reduce for shared memory not implemented.".
 * - Supports both 1D reductions (scalar output) and reductions along a single
 *   extra dimension; validates layout dimensionality consistency.
 * - If `clear` is set (or for sum/abssum reductions), an initial value is
 *   written to the clear buffer; for non-clearing sum/abssum a duplicate
 *   temporary buffer is allocated and accumulated back into dst after
 * reduction.
 * - Performs iterator compression for local reduction loops using `analyzer`.
 * - Detects parallel thread splitting from the normalized iterator sum and
 *   emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
 *   via `builtin::call_extern`. For sufficiently large reducing thread counts
 *   (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
 *   AllReduce call.
 * - The final body is wrapped in parallel loops over the destination spatial
 *   dimensions and partitioned by the lowering thread variable. If a temporary
 *   clear buffer is used, it is allocated for the body.
 *
 * @param T Lowering context providing buffer and layout maps, thread bounds,
 *          target information, thread variable, and workspace allocation
 * helper.
 * @param analyzer Analyzer used for iterator compression and arithmetic
 * normalization.
 * @return Stmt Lowered TIR statement implementing the reduction.
187
 */
188
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
189
190
191
192
193
  auto get_buffer = [&](const Buffer &buf) {
    if (T.buffer_remap.count(buf))
      return T.buffer_remap[buf];
    return buf;
  };
194

195
196
  auto src_scope = this->src.scope();
  auto dst_scope = this->dst.scope();
197

198
  if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
199

200
201
202
203
204
205
    Buffer src_buffer = get_buffer(this->src);
    Buffer dst_buffer = get_buffer(this->dst);
    Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
    Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
    size_t src_dim = src_layout->InputDim();
    size_t dst_dim = dst_layout->InputDim();
206

207
    bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
208

209
210
211
212
213
214
    if (is_1d_reduce) {
      ICHECK(is_one(dst_layout->OutputShape().back()))
          << "Reduce for scalar not implemented.";
    } else {
      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
    }
215

216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
    Array<IterVar> dst_vars;
    for (size_t i = 0; i < dst_dim; ++i) {
      Var var = Var(std::string{char('i' + i)});
      dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                                 IterVarType::kDataPar));
    }

    Array<IterVar> src_vars;
    if (!is_1d_reduce) {
      src_vars = dst_vars;
    }
    Range reduce_dom(0, src_layout->InputShape()[this->dim]);
    IterVar reduce_iv(reduce_dom, Var("rv"), IterVarType::kDataPar);
    src_vars.insert(src_vars.begin() + this->dim, reduce_iv);

    Array<PrimExpr> src_indices = src_layout->Forward(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
    Array<PrimExpr> dst_indices = dst_layout->Forward(
        dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

    Array<Stmt> stmts;

    bool require_init = this->clear;
    if (this->type->isSum() || this->type->isAbsSum() ||
        this->type->isBitAnd() || this->type->isBitOr() ||
        this->type->isBitXor()) {
      require_init = true;
    }

    Buffer clear_buffer = dst_buffer;
    bool need_duplicate = false;
    if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
      need_duplicate = true;
    } else if (this->type->isBitAnd() && !this->clear) {
      need_duplicate = true;
    } else if ((this->type->isBitOr() || this->type->isBitXor()) &&
               !this->clear) {
      need_duplicate = true;
    }

    if (need_duplicate) {
      // Create a new buffer with same shape and dtype as dst_buffer
      clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                                 dst_buffer->name + "_clear",
                                 GetPtrStorageScope(dst_buffer->data));
    }
    // make reduce-init stmt
    if (require_init) {
      stmts.push_back(
          BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
    }

    // make thread-local reduce
    Array<PrimExpr> src_indice_compressed;
    Array<IterVar> src_var_compressed;
    for (size_t i = 0; i < src_layout->OutputDim(); ++i) {
      PrimExpr expr;
      IterVar var;
      std::tie(expr, var) = CompressIterator(
          src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
      src_indice_compressed.push_back(expr);
      src_var_compressed.push_back(var);
    }

    Stmt reduce_local = BufferStore(
        clear_buffer,
        this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                         BufferLoad(src_buffer, src_indice_compressed)),
        dst_indices);

    for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
      reduce_local =
          For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
              ForKind::kUnrolled, reduce_local, std::nullopt,
              {{tir::attr::pragma_unroll_explicit, Bool(false)}});
    }
    stmts.push_back(reduce_local);

    PrimExpr src_thread = src_layout->ForwardThread(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
    auto iter_sum =
        arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
    for (const auto &iter_split : iter_sum->args) {
      auto mark = iter_split->source->source.as<Var>();
      ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
      if (mark.value().same_as(src_vars[this->dim]->var)) {
        auto scale = as_const_int(iter_split->scale);
        auto extent = as_const_int(iter_split->extent);
        ICHECK(scale != nullptr && extent != nullptr);
        if (*extent == 1)
          continue;

        int reducing_threads = (*extent) * (*scale);
        std::stringstream ss;

        auto thread_offset = T.thread_bounds->min;
        if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
          auto all_threads = T.thread_bounds->extent;
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ", " << all_threads << ">::run_hopper";
        } else {
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ">::run";
        }
        Array<PrimExpr> thread_reduce_args = {
            StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
        if (reducing_threads >= 32) {
          PrimExpr workspace = T.AddWorkspace(
              *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
          thread_reduce_args.push_back(workspace);
        }
        auto call = Call(clear_buffer->dtype, builtin::call_extern(),
                         thread_reduce_args);
        stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
332
      }
333
    }
334
335
336
337
338
339
340
341
342
343
344
345
346

    if (need_duplicate) {
      PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
      PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
      PrimExpr update;
      if (this->type->isSum() || this->type->isAbsSum()) {
        update = dst_val + src_val;
      } else if (this->type->isBitAnd()) {
        update = this->clear ? src_val : bitwise_and(dst_val, src_val);
      } else if (this->type->isBitOr()) {
        update = bitwise_or(dst_val, src_val);
      } else if (this->type->isBitXor()) {
        update = bitwise_xor(dst_val, src_val);
347
      } else {
348
        LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
349
      }
350
351
352
353
354
355
356
357
358
359
360
361
      stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
    }

    Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
    for (int i = static_cast<int>(dst_layout->InputDim()) - 1; i >= 0; --i) {
      body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
                 ForKind::kParallel, body);
    }

    if (dst_layout->InputDim() > 0) {
      body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
                           dst_layout);
362
    } else {
363
364
      PrimExpr guard = (T.thread_var == T.thread_bounds->min);
      body = IfThenElse(guard, body);
365
    }
366
367
368
369
370
371

    if (need_duplicate) {
      body = Allocate(clear_buffer->data, clear_buffer->dtype,
                      clear_buffer->shape, const_true(), body);
    }
    return body;
372
373
  }

374
375
376
  LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
             << dst_scope << ") is not implemented.";
  return Stmt();
377
378
}

379
380
/**
 * @brief Infer the dst fragment layout from the src fragment layout.
 *
 * Inference only happens below InferLevel::kStrict, and only for
 * fragment-to-fragment reductions whose src layout is already known. The dst
 * layout drops the reduced axis from the index mapping and folds its extent
 * into the replication extent.
 *
 * If dst already has a layout, the computed one must be contained in it.
 * Otherwise a LayoutConflictException is thrown. The computed layout is only
 * returned when it is more replicated than src.
 *
 * @param T Inference context (layout map, thread bounds).
 * @param level Current inference level.
 * @return A map with the inferred dst layout, or an empty map.
 */
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (level >= InferLevel::kStrict)
    return {};

  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    // Every element along the reduced axis maps onto the same dst element,
    // so the reduced extent becomes extra replication in dst.
    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    // Map dst indices back to src indices. The reduced axis is taken from the
    // replication placeholder modulo the reduced extent.
    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    // The quotient of the replication placeholder is used as the src
    // replica index.
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));

    // Ensure the thread count is divisible by the replicate extent.
    // Otherwise, we cannot infer a valid fragment<->fragment layout.
    {
      arith::Analyzer analyzer;
      PrimExpr num_threads = T.thread_bounds->extent;
      // Though the dest_buffer_rep_extent will be compressed at
      // CondenseReplicateVar, we need to check the divisibility here to avoid
      // the issue that the thread count is not divisible by the replicate
      // extent.
      if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) ==
                             0) &&
          !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) ==
                             0)) {
        // Report the extent that was actually checked.
        ICHECK(false) << "ReduceOp fragment layout inference failed: "
                         "num_threads % replicate_extent != 0. "
                      << "This mapping requires the block's thread count to be "
                         "divisible by the "
                      << "replicate extent. "
                      << "Try one of: (1) choose a thread block size divisible "
                         "by replicate_extent; "
                      << "(2) pick a different reduce dimension or adjust the "
                         "source fragment layout; "
                      << "Details: num_threads=" << num_threads
                      << ", replicate_extent=" << dest_buffer_rep_extent
                      << ", src=" << src << ", dst=" << dst;
      }
    }

    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);

    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check that the computed layout is compatible with the existing one:
      // the existing layout must strictly contain the computed layout.
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (size_t i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
  }
  return {};
}

// Register tl.reduce. The ReduceOp constructor consumes five call arguments
// (src, dst, reduce_type, dim, clear), so the declared arity is 5.
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
486

487
488
489
490
491
492
493
494
495
496
// Wrap a whole Buffer as a BufferRegion: each axis gets the range
// [0, extent), with the zero start typed like that axis's extent.
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
  Array<Range> full_ranges = buf->shape.Map([](const PrimExpr &dim_extent) {
    return Range(IntImm(dim_extent->dtype, 0), dim_extent);
  });
  return BufferRegion(buf, full_ranges);
}

497
/**
 * @brief Build a CumSumOp from the arguments of a `tl.cumsum` call.
 *
 * Expected call arguments:
 *  - args[0]: src, normalized via NormalizeToBufferRegion
 *  - args[1]: dst, normalized via NormalizeToBufferRegion
 *  - args[2]: dimension to cumsum along (IntImm), must lie in [0, ndim)
 *  - args[3]: whether to cumsum in reverse order (Bool)
 */
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
  node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
  node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
  node->src = node->srcRegion_->buffer;
  node->dst = node->dstRegion_->buffer;

  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();

  // Lower() uses `dim` directly: it compares it to 0 and emits it as a
  // CumSum2D template argument. A negative value must therefore be rejected
  // here, just like an out-of-range one.
  CHECK_GE(node->dim, 0)
      << "The dim of cumsum should be non-negative. Got dim=" << node->dim
      << ".";
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
      << "The dim of cumsum should be less than the number of dimensions. Got "
         "dim="
      << node->dim << ", but src has " << node->src->shape.size() << " dims.";

  data_ = std::move(node);
}

521
/**
 * @brief Lower cumsum to an extern call into tl::CumSum1D / tl::CumSum2D.
 *
 * Only shared-memory ("shared" / "shared.dyn") source and destination buffers
 * are supported. Fragment-to-fragment cumsum and any other scope combination
 * abort.
 *
 * The template arguments are the thread-block extent, (for 2D) the scan
 * dimension, and the reverse flag. The call arguments are the src/dst access
 * pointers followed by the buffer extents.
 */
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
    return Stmt();
  }

  const bool src_in_shared =
      this->src.scope() == "shared.dyn" || this->src.scope() == "shared";
  if (!src_in_shared) {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
    return Stmt();
  }
  ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");

  auto threads = T.thread_bounds->extent;
  const int ndim = static_cast<int>(src->shape.size());

  // Access pointers are built from the stored regions. The mask is 1 for src
  // and 2 for dst (presumably read / write).
  PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
  PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);
  const char *reverse_flag = reverse ? "true" : "false";

  std::stringstream kernel_name;
  Array<PrimExpr> call_args;
  switch (ndim) {
  case 1:
    ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
                         "= 0.";
    kernel_name << "tl::CumSum1D<" << threads << ", " << reverse_flag
                << ">::run";
    call_args = {StringImm(kernel_name.str()), srcPtr, dstPtr, src->shape[0]};
    break;
  case 2:
    kernel_name << "tl::CumSum2D<" << threads << ", " << dim << ", "
                << reverse_flag << ">::run";
    call_args = {StringImm(kernel_name.str()), srcPtr, dstPtr, src->shape[0],
                 src->shape[1]};
    break;
  default:
    LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
               << ndim << "D.";
    break;
  }
  return Evaluate(Call(dst->dtype, builtin::call_extern(), call_args));
}

562
563
/// @brief CumSum contributes no layout constraints: the returned map is
/// always empty, regardless of the inference level.
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  LayoutMap no_constraints;
  return no_constraints;
}

// Register tl.cumsum with four call arguments (src, dst, dim, reverse) and an
// opaque call-effect kind.
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque))
    .set_num_inputs(4);
571
572
573
574
575
576
577

TVM_FFI_STATIC_INIT_BLOCK() {
  ReduceOpNode::RegisterReflection();
  CumSumOpNode::RegisterReflection();
  ReduceTypeNode::RegisterReflection();
}

578
} // namespace tl
579
} // namespace tvm