reduce.cc 16.6 KB
Newer Older
1
2
/*!
 * \file tl/op/reduce.cc
3
 * \brief Implementation of reduction operators
4
5
6
7
8
9
10
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
11
#include <tvm/tir/stmt_functor.h>
12
13

#include "../layout/utils.h"
14
#include "../op/parallel.h"
15
#include "../target/utils.h"
16
#include "../transform/loop_partition.h"
17
#include "tir/transforms/ir_utils.h"
18
19
20
21
22
23
24

namespace tvm {
namespace tl {

using namespace tir;

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
25
26
27
28
29
  ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
30
  node->type = ReduceType(reduce_type);
31
32
  node->clear = args[4].as<Bool>().value();
  data_ = std::move(node);
33
34
}

35
36
37
38
39
40
41
42
43
44
45
TileOperator ReduceOpNode::Clone() const {
  auto op = make_object<ReduceOpNode>(*this);
  return ReduceOp(op);
}

TileOperator CumSumOpNode::Clone() const {
  auto op = make_object<CumSumOpNode>(*this);
  return CumSumOp(op);
}

PrimExpr ReduceOpNode::MakeInitValue() const {
46
47
48
49
50
  auto dst_dtype = dst->dtype;
  auto is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();
  auto bits = dst_dtype.bits();

51
  if (type->isSum()) {
52
    return make_zero(dst->dtype);
53
  } else if (type->isAbsSum()) {
54
    return make_zero(dst->dtype);
55
  } else if (type->isMax()) {
56
57
58
59
60
61
62
    if (is_int) {
      return make_const(dst->dtype, -(1 << (bits - 1)));
    } else if (is_uint) {
      return make_const(dst->dtype, 0);
    } else {
      return make_const(dst->dtype, -INFINITY);
    }
63
  } else if (type->isMin()) {
64
65
66
67
68
69
70
    if (is_int) {
      return make_const(dst->dtype, (1 << (bits - 1)) - 1);
    } else if (is_uint) {
      return make_const(dst->dtype, (1 << bits) - 1);
    } else {
      return make_const(dst->dtype, INFINITY);
    }
71
  } else if (type->isAbsMax()) {
72
    return make_const(dst->dtype, 0);
73
74
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
75
76
77
  }
}

78
79
80
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
                                  const PrimExpr &b) const {
  PrimExpr rhs = b;
81
82
83
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
84
  if (type->isSum()) {
85
    return lhs + rhs;
86
  } else if (type->isAbsSum()) {
87
    return lhs + Max(rhs, -rhs);
88
  } else if (type->isMax()) {
89
    return Max(lhs, rhs);
90
  } else if (type->isMin()) {
91
    return Min(lhs, rhs);
92
  } else if (type->isAbsMax()) {
93
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
94
95
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
96
97
98
  }
}

99
std::string ReduceOpNode::MakeCodegenReducer() const {
100
  if (type->isSum()) {
101
    return "tl::SumOp";
102
  } else if (type->isAbsSum()) {
103
    return "tl::SumOp";
104
  } else if (type->isMax()) {
105
    return "tl::MaxOp";
106
  } else if (type->isMin()) {
107
    return "tl::MinOp";
108
  } else if (type->isAbsMax()) {
109
    return "tl::MaxOp";
110
111
  } else {
    LOG(FATAL) << "Unsupported reduce type: " << type->type;
112
    return "";
113
114
115
  }
}

116
/**
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
 * @brief Lower the Reduce operator to a TIR statement.
 *
 * Lowers a ReduceOpNode operating on fragment-scoped buffers into a sequence of
 * TIR statements implementing: optional initialization, thread-local reduction
 * (unrolled inner loops), inter-thread reduction via a runtime AllReduce call
 * (Hopper-specific `run_hopper` variant when TargetIsHopper(T.target) is true),
 * and an optional accumulation or copy back to the destination buffer when a
 * temporary clear buffer is used.
 *
 * Behavior notes:
 * - Only supports src and dst in "local.fragment" scope; otherwise it checks
 *   and aborts with "Reduce for shared memory not implemented.".
 * - Supports both 1D reductions (scalar output) and reductions along a single
 *   extra dimension; validates layout dimensionality consistency.
 * - If `clear` is set (or for sum/abssum reductions), an initial value is
 *   written to the clear buffer; for non-clearing sum/abssum a duplicate
 *   temporary buffer is allocated and accumulated back into dst after
 * reduction.
 * - Performs iterator compression for local reduction loops using `analyzer`.
 * - Detects parallel thread splitting from the normalized iterator sum and
 *   emits a call to a templated `tl::AllReduce<...>::run` (or `run_hopper`)
 *   via `builtin::call_extern`. For sufficiently large reducing thread counts
 *   (>= 32) a workspace is allocated via T.AddWorkspace and passed to the
 *   AllReduce call.
 * - The final body is wrapped in parallel loops over the destination spatial
 *   dimensions and partitioned by the lowering thread variable. If a temporary
 *   clear buffer is used, it is allocated for the body.
 *
 * @param T Lowering context providing buffer and layout maps, thread bounds,
 *          target information, thread variable, and workspace allocation
 * helper.
 * @param analyzer Analyzer used for iterator compression and arithmetic
 * normalization.
 * @return Stmt Lowered TIR statement implementing the reduction.
151
 */
152
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
153
154
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
155
156
157
158
159
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
160
161
162
163
164
165
166
167
168
169
170
171
  size_t src_dim = src_layout->InputDim();
  size_t dst_dim = dst_layout->InputDim();

  bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;

  if (is_1d_reduce) {
    ICHECK(is_one(dst_layout->OutputShape().back()))
        << "Reduce for scalar not implemented.";
  } else {
    ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
  }

172
  Array<IterVar> dst_vars;
173
  for (size_t i = 0; i < dst_dim; i++) {
174
    Var var = Var(std::string{char('i' + i)});
175
176
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
177
  }
178
179
180
181
  Array<IterVar> src_vars;
  if (!is_1d_reduce) {
    src_vars = dst_vars;
  }
182
183
184
185
186
187
188
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
189
190
191

  Array<Stmt> stmts;

192
193
  bool require_init = this->clear;
  // sum op must be cleared
194
  if (this->type->isSum()) {
195
    require_init = true;
196
  } else if (this->type->isAbsSum()) {
197
198
199
200
201
    require_init = true;
  }

  Buffer clear_buffer = dst_buffer;
  bool need_duplicate = false;
202
  if (this->type->isSum() && !this->clear) {
203
    need_duplicate = true;
204
  } else if (this->type->isAbsSum() && !this->clear) {
205
206
207
208
209
210
211
212
213
214
    need_duplicate = true;
  }

  if (need_duplicate) {
    // Create a new buffer with same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

215
  // make reduce-init stmt
216
  if (require_init)
217
    stmts.push_back(
218
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
219
220
221
222
223
224
225

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
226
227
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
228
229
230
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
231
  Stmt reduce_local = BufferStore(
232
233
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
234
235
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
236
237
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
238
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
239
            ForKind::kUnrolled, reduce_local, std::nullopt,
240
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
241
242
243
244
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
245
246
247
248
249
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
250
    auto mark = iter_split->source->source.as<Var>();
251
    ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
252
253
254
255
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
256
257
      if (*extent == 1)
        continue;
258

259
260
      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
261

262
      auto thread_offset = T.thread_bounds->min;
263
      if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
264
        auto all_threads = T.thread_bounds->extent;
265
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
266
267
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ", " << all_threads << ">::run_hopper";
268
269
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
270
271
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ">::run";
272
      }
273
      Array<PrimExpr> thread_reduce_args = {
274
          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
275
      if (reducing_threads >= 32) {
276
        PrimExpr workspace = T.AddWorkspace(
277
            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
278
279
        thread_reduce_args.push_back(workspace);
      }
280
      auto call =
281
282
283
284
285
286
287
288
289
290
          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread = BufferStore(
      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

  // copy clear_buffer to dst_buffer
  if (need_duplicate) {
    // if is reduce sum, we should add a copy from clear_buffer to dst_buffer
291
    if (this->type->isSum()) {
292
293
294
295
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
296
    } else if (this->type->isAbsSum()) {
297
298
299
300
301
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else {
302
      ICHECK(false) << "Unsupported reduce type: " << this->type->type;
303
304
305
306
307
    }
  }
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
308
309
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
310
311
312
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
313
314
315
316
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
317
318
319
  return body;
}

320
321
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
322
323
  if (level >= InferLevel::kStrict)
    return {};
324
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
325
      T.layout_map.count(src)) {
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
342
343
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
344
    Fragment dst_layout =
345
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
346
347
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check if computed layout is compatible with existing: the existing one
      // must strictly contains the computed layout
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (int i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
388
389
390
391
392
393
  }
  return {};
}

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
394
395
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
396

397
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
398
399
400
401
402
  /// CumSum constructor arguments:
  /// - src: input buffer
  /// - dst: output buffer
  /// - dim: dimension to cumsum
  /// - reverse: whether to cumsum in reverse order
403
  CHECK_EQ(args.size(), 4);
404
405
406
407
408
409
410
  ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
  data_ = std::move(node);
411
412
}

413
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
414
415
416
417
418
419
420
421
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
422
    auto threads = T.thread_bounds->extent;
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    for (int i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}

439
440
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
441
442
443
444
445
446
447
  return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
448
449
} // namespace tl
} // namespace tvm