"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "ab3c9971cf132fe89ac5a0b95d4dbf9996cb1411"
reduce.cc 19.2 KB
Newer Older
1
2
/*!
 * \file tl/op/reduce.cc
 * \brief Implementation of reduction operators
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
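  /// Reduce constructor arguments (as unpacked below):
  /// - args[0]: source buffer access pointer
  /// - args[1]: destination buffer access pointer
  /// - args[2]: reduce type name, parsed by ReduceType
  /// - args[3]: dimension to reduce over
  /// - args[4]: whether dst is cleared to the reduce identity before reducing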
  ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
  node->type = ReduceType(reduce_type);
  node->clear = args[4].as<Bool>().value();
  data_ = std::move(node);
}

TileOperator ReduceOpNode::Clone() const {
  auto op = make_object<ReduceOpNode>(*this);
  return ReduceOp(op);
}

TileOperator CumSumOpNode::Clone() const {
  auto op = make_object<CumSumOpNode>(*this);
  return CumSumOp(op);
}

PrimExpr ReduceOpNode::MakeInitValue() const {
  auto dst_dtype = dst->dtype;
  auto is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();
  auto bits = dst_dtype.bits();

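  // Concrete identities implied by the branches below: an int8 max-reduction
  // starts from -128, an unsigned min-reduction from the all-ones pattern,
  // and a float min-reduction from +INFINITY.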
  if (type->isSum()) {
    return make_zero(dst->dtype);
  } else if (type->isAbsSum()) {
    return make_zero(dst->dtype);
  } else if (type->isMax()) {
    if (is_int) {
      // use 64-bit shifts so 32-bit element types do not overflow
      return make_const(dst->dtype, -(int64_t(1) << (bits - 1)));
    } else if (is_uint) {
      return make_const(dst->dtype, 0);
    } else {
      return make_const(dst->dtype, -INFINITY);
    }
  } else if (type->isMin()) {
    if (is_int) {
      return make_const(dst->dtype, (int64_t(1) << (bits - 1)) - 1);
    } else if (is_uint) {
      return make_const(dst->dtype, (int64_t(1) << bits) - 1);
    } else {
      return make_const(dst->dtype, INFINITY);
    }
  } else if (type->isAbsMax()) {
    return make_const(dst->dtype, 0);
  } else if (type->isBitAnd()) {
    if (is_int) {
      return make_const(dst->dtype, -1);
    } else if (is_uint) {
      return make_const(dst->dtype, (int64_t(1) << bits) - 1);
    } else {
      // Should not arrive here: bitwise reductions expect integer dtypes
      return make_const(dst->dtype, -INFINITY);
    }
  } else if (type->isBitOr()) {
    return make_zero(dst->dtype);
  } else if (type->isBitXor()) {
    return make_zero(dst->dtype);
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
  }
}

PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
                                  const PrimExpr &b) const {
  PrimExpr rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
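  // The absolute value is folded into the combine step itself: for AbsSum the
  // expression below is lhs + max(rhs, -rhs) == lhs + |rhs|, and for AbsMax
  // max(max(lhs, rhs), -min(lhs, rhs)) == max(|lhs|, |rhs|).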
  if (type->isSum()) {
    return lhs + rhs;
  } else if (type->isAbsSum()) {
    return lhs + Max(rhs, -rhs);
  } else if (type->isMax()) {
    return Max(lhs, rhs);
  } else if (type->isMin()) {
    return Min(lhs, rhs);
  } else if (type->isAbsMax()) {
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  } else if (type->isBitAnd()) {
    return lhs & rhs;
  } else if (type->isBitOr()) {
    return lhs | rhs;
  } else if (type->isBitXor()) {
    return lhs ^ rhs;
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
  }
}

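// AbsSum and AbsMax reuse tl::SumOp and tl::MaxOp here: the absolute value has
// already been applied by MakeReduce in the thread-local combine, so the
// inter-thread AllReduce only needs a plain sum or max.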
std::string ReduceOpNode::MakeCodegenReducer() const {
  if (type->isSum()) {
    return "tl::SumOp";
  } else if (type->isAbsSum()) {
    return "tl::SumOp";
  } else if (type->isMax()) {
    return "tl::MaxOp";
  } else if (type->isMin()) {
    return "tl::MinOp";
  } else if (type->isAbsMax()) {
    return "tl::MaxOp";
  } else if (type->isBitAnd()) {
    return "tl::BitAndOp";
  } else if (type->isBitOr()) {
    return "tl::BitOrOp";
  } else if (type->isBitXor()) {
    return "tl::BitXorOp";
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
    return "";
  }
}

/**
 * @brief Lower the Reduce operator to a TIR statement.
 *
 * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence
 * of TIR statements implementing: optional initialization, thread-local
 * reduction (unrolled inner loops), inter-thread reduction via a runtime
 * AllReduce call (the `run_hopper` variant when the target is Hopper or
 * SM100), and an optional accumulation or copy back to the destination buffer
 * when a temporary clear buffer is used.
 *
 * Behavior notes:
 * - Only src and dst in "local.fragment" scope are supported; otherwise
 *   lowering aborts with "Reduce for shared memory not implemented.".
 * - Supports both 1D reductions (scalar output) and reductions along a single
 *   extra dimension; validates layout dimensionality consistency.
 * - If `clear` is set (or for sum/abssum and bitwise reductions), an initial
 *   value is written to the clear buffer; for reductions that must not
 *   overwrite an existing dst, a duplicate temporary buffer is allocated and
 *   merged back into dst after the reduction.
 * - Performs iterator compression for local reduction loops using `analyzer`.
 * - Detects parallel thread splitting from the normalized iterator sum and
 *   emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
 *   via `builtin::call_extern`. For sufficiently large reducing thread counts
 *   (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
 *   AllReduce call.
 * - The final body is wrapped in parallel loops over the destination spatial
 *   dimensions and partitioned by the lowering thread variable. If a
 *   temporary clear buffer is used, it is allocated around the body.
 *
 * @param T Lowering context providing buffer and layout maps, thread bounds,
 *          target information, the thread variable, and a workspace
 *          allocation helper.
 * @param analyzer Analyzer used for iterator compression and arithmetic
 *          normalization.
 * @return Stmt Lowered TIR statement implementing the reduction.
 */
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  size_t src_dim = src_layout->InputDim();
  size_t dst_dim = dst_layout->InputDim();

  bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;

  if (is_1d_reduce) {
    ICHECK(is_one(dst_layout->OutputShape().back()))
        << "Reduce for scalar not implemented.";
  } else {
    ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
  }

  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_dim; i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }

  Array<IterVar> src_vars;
  if (!is_1d_reduce) {
    src_vars = dst_vars;
  }
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  bool require_init = this->clear;
  // sum/abssum and bitwise ops must always start from the reduce identity
  if (this->type->isSum()) {
    require_init = true;
  } else if (this->type->isAbsSum()) {
    require_init = true;
  } else if (this->type->isBitAnd()) {
    require_init = true;
  } else if (this->type->isBitOr()) {
    require_init = true;
  } else if (this->type->isBitXor()) {
    require_init = true;
  }

  Buffer clear_buffer = dst_buffer;
  bool need_duplicate = false;
  if (this->type->isSum() && !this->clear) {
    need_duplicate = true;
  } else if (this->type->isAbsSum() && !this->clear) {
    need_duplicate = true;
  } else if (this->type->isBitAnd()) {
    need_duplicate = true;
  } else if (this->type->isBitOr() && !this->clear) {
    need_duplicate = true;
  } else if (this->type->isBitXor() && !this->clear) {
    need_duplicate = true;
  }

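  // When dst may already hold a live partial result, the reduction is carried
  // out in a separate clear_buffer initialized to the identity and merged back
  // into dst after the inter-thread step.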
  if (need_duplicate) {
    // Create a new buffer with same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

  // make reduce-init stmt
  if (require_init) {
    stmts.push_back(
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
  }

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, std::nullopt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;

      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;

      auto thread_offset = T.thread_bounds->min;
      if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
        auto all_threads = T.thread_bounds->extent;
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ", " << all_threads << ">::run_hopper";
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ">::run";
      }
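      // Illustrative only (assumed values): with tl::SumOp, 32 reducing
      // threads, scale 1, thread offset 0 and 128 threads in total, the string
      // built above is "tl::AllReduce<tl::SumOp, 32, 1, 0, 128>::run_hopper".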
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(
            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread = BufferStore(
      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

  // merge clear_buffer back into dst_buffer
  if (need_duplicate) {
    // sum/abssum accumulate clear_buffer into dst_buffer; bitwise reductions
    // combine with the matching bitwise op (or copy it over when clear is set)
    if (this->type->isSum()) {
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else if (this->type->isAbsSum()) {
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else if (this->type->isBitAnd()) {
      if (!this->clear) {
        stmts.push_back(
            BufferStore(dst_buffer,
                        bitwise_and(BufferLoad(dst_buffer, dst_indices),
                                    BufferLoad(clear_buffer, dst_indices)),
                        dst_indices));
      } else {
        stmts.push_back(BufferStore(
            dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
      }
    } else if (this->type->isBitOr()) {
      stmts.push_back(
          BufferStore(dst_buffer,
                      bitwise_or(BufferLoad(dst_buffer, dst_indices),
                                 BufferLoad(clear_buffer, dst_indices)),
                      dst_indices));
    } else if (this->type->isBitXor()) {
      stmts.push_back(
          BufferStore(dst_buffer,
                      bitwise_xor(BufferLoad(dst_buffer, dst_indices),
                                  BufferLoad(clear_buffer, dst_indices)),
                      dst_indices));
    } else {
      ICHECK(false) << "Unsupported reduce type: " << this->type->type;
    }
  }
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
  return body;
}

LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
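    // fwd maps a dst coordinate (plus a replication index) back to a src
    // coordinate: the reduced axis `dim` is driven by the replication
    // placeholder, so each replicated copy of dst covers one slice of the
    // reduced dimension while the other axes pass through unchanged.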
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);
    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check if the computed layout is compatible with the existing one: the
      // existing layout must strictly contain the computed layout
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (int i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
  }
  return {};
}

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /// CumSum constructor arguments:
  /// - src: input buffer
  /// - dst: output buffer
  /// - dim: dimension to cumsum
  /// - reverse: whether to cumsum in reverse order
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
  data_ = std::move(node);
}

Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
    auto threads = T.thread_bounds->extent;
    Array<PrimExpr> args;
    int ndim = static_cast<int>(src->shape.size());
    if (ndim == 1) {
      ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
                           "= 0.";
      ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
         << ">::run";
      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
              src->shape[0]};
    } else if (ndim == 2) {
      ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
         << (reverse ? "true" : "false") << ">::run";
      args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
              src->shape[0], src->shape[1]};
    } else {
      LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
                 << ndim << "D.";
    }
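    // Illustrative only (assumed values): a [64, 128] shared buffer cumulated
    // over dim = 1 with 128 threads and reverse = false yields the extern call
    //   tl::CumSum2D<128, 1, false>::run(src_ptr, dst_ptr, 64, 128)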
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}

LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
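  // CumSum is currently lowered only for shared-memory buffers (see Lower
  // above), so it places no constraints on fragment layouts here.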
  return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm