reduce.cc 14.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*!
 * \file tl/op/reduce.cc
 *
 * Define reduce operator.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
12
#include <tvm/tir/stmt_functor.h>
13
14

#include "../layout/utils.h"
15
#include "../op/parallel.h"
16
#include "../transform/loop_partition.h"
17
#include "tir/transforms/ir_utils.h"
18
19
20
21
22
23
24

namespace tvm {
namespace tl {

using namespace tir;

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    Reduce arguments:
      args[0]: access pointer of the source buffer
      args[1]: access pointer of the destination buffer
      args[2]: reduce type string ("sum", "abssum", "absmax", "max", "min")
      args[3]: dimension (axis) to reduce over
      args[4]: whether to clear (initialize) the destination first
   */
  // Guard before indexing args[0..4]; mirrors the size check done by the
  // CumSumOp constructor below.
  ICHECK_GE(args.size(), 5) << "ReduceOp expects 5 arguments.";
  ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  std::string reduce_type = args[2].as<StringImm>().value()->value;
  node->dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    node->type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    node->type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    node->type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    node->type = ReduceType::kMax;
  else if (reduce_type == "min")
    node->type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  node->clear = args[4].as<Bool>().value();
  data_ = std::move(node);
}

46
47
48
49
50
51
52
53
54
55
56
TileOperator ReduceOpNode::Clone() const {
  // Deep-copy this node and hand ownership to a fresh ReduceOp wrapper.
  return ReduceOp(make_object<ReduceOpNode>(*this));
}

TileOperator CumSumOpNode::Clone() const {
  // Deep-copy this node and hand ownership to a fresh CumSumOp wrapper.
  return CumSumOp(make_object<CumSumOpNode>(*this));
}

// Returns the identity element for this reduction, typed as the destination
// dtype. This value initializes the accumulator before the reduce loop.
PrimExpr ReduceOpNode::MakeInitValue() const {
  auto dst_dtype = dst->dtype;
  bool is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();

  switch (type) {
  case ReduceType::kSum:
    return make_zero(dst_dtype);
  case ReduceType::kAbsSum:
    return make_zero(dst_dtype);
  case ReduceType::kMax:
    if (is_int || is_uint) {
      // Use TVM's min_value instead of hand-computing -(1 << (bits - 1)):
      // the manual shift is undefined behavior for 32-bit dtypes and wrong
      // for 64-bit ones. (For uint dtypes min_value is 0, matching the
      // previous explicit 0 initializer.)
      return min_value(dst_dtype);
    } else {
      return make_const(dst_dtype, -INFINITY);
    }
  case ReduceType::kMin:
    if (is_int || is_uint) {
      // max_value replaces (1 << (bits - 1)) - 1 / (1 << bits) - 1, both of
      // which overflow `int` for bits >= 32 (e.g. uint32 needed 1 << 32).
      return max_value(dst_dtype);
    } else {
      return make_const(dst_dtype, INFINITY);
    }
  case ReduceType::kAbsMax:
    // |x| >= 0, so 0 is the identity for abs-max.
    return make_const(dst_dtype, 0);
  default:
    ICHECK(0);
    return PrimExpr(); // unreachable; silences missing-return warnings
  }
}

90
PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
91
92
93
94
95
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
96
97
98
99
100
101
102
103
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
104
105
  case ReduceType::kAbsMax:
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
106
107
108
  default:
    ICHECK(0);
    return PrimExpr(0);
109
110
111
  }
}

112
std::string ReduceOpNode::MakeCodegenReducer() const {
113
  switch (type) {
114
115
116
117
118
119
120
121
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
122
123
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
124
125
126
  default:
    ICHECK(0);
    return "";
127
128
129
  }
}

130
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
131
132
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
133
134
135
136
137
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
138
139
140
141
142
143
144
145
146
147
148
149
  size_t src_dim = src_layout->InputDim();
  size_t dst_dim = dst_layout->InputDim();

  bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;

  if (is_1d_reduce) {
    ICHECK(is_one(dst_layout->OutputShape().back()))
        << "Reduce for scalar not implemented.";
  } else {
    ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
  }

150
  Array<IterVar> dst_vars;
151
  for (size_t i = 0; i < dst_dim; i++) {
152
    Var var = Var(std::string{char('i' + i)});
153
154
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
155
  }
156
157
158
159
  Array<IterVar> src_vars;
  if (!is_1d_reduce) {
    src_vars = dst_vars;
  }
160
161
162
163
164
165
166
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
167
168
169

  Array<Stmt> stmts;

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  bool require_init = this->clear;
  // sum op must be cleared
  if (this->type == ReduceType::kSum) {
    require_init = true;
  } else if (this->type == ReduceType::kAbsSum) {
    require_init = true;
  }

  Buffer clear_buffer = dst_buffer;
  bool need_duplicate = false;
  if (this->type == ReduceType::kSum && !this->clear) {
    need_duplicate = true;
  } else if (this->type == ReduceType::kAbsSum && !this->clear) {
    need_duplicate = true;
  }

  if (need_duplicate) {
    // Create a new buffer with same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

193
  // make reduce-init stmt
194
  if (require_init)
195
    stmts.push_back(
196
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
197
198
199
200
201
202
203

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
204
205
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
206
207
208
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
209
  Stmt reduce_local = BufferStore(
210
211
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
212
213
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
214
215
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
216
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
217
            ForKind::kUnrolled, reduce_local, std::nullopt,
218
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
219
220
221
222
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
223
224
225
226
227
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
228
    auto mark = iter_split->source->source.as<Var>();
229
    ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
230
231
232
233
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
234
235
      if (*extent == 1)
        continue;
236

237
238
      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
239
240

      bool has_arch = T.target->attrs.count("arch") > 0;
241
      auto thread_offset = T.thread_bounds->min;
242
      if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
243
        auto all_threads = T.thread_bounds->extent;
244
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
245
246
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ", " << all_threads << ">::run_hopper";
247
248
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
249
250
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ">::run";
251
      }
252
      Array<PrimExpr> thread_reduce_args = {
253
          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
254
      if (reducing_threads >= 32) {
255
        PrimExpr workspace = T.AddWorkspace(
256
            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
257
258
        thread_reduce_args.push_back(workspace);
      }
259
      auto call =
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread = BufferStore(
      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

  // copy clear_buffer to dst_buffer
  if (need_duplicate) {
    // if is reduce sum, we should add a copy from clear_buffer to dst_buffer
    if (this->type == ReduceType::kSum) {
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else if (this->type == ReduceType::kAbsSum) {
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else {
      ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
282
283
284
285
286
    }
  }
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
287
288
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
289
290
291
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
292
293
294
295
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
296
297
298
  return body;
}

// Infer the destination fragment layout implied by reducing `src` over
// axis `dim`. Only applies to fragment->fragment reduces once the source
// layout is known; returns {} otherwise. If the destination already has a
// layout, verify the two are compatible and throw LayoutConflictException
// if not.
LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  // At strict level no speculative layouts are proposed.
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    // The reduced axis collapses into replication: every value along it
    // ends up replicated across the threads that used to own it, on top of
    // the source's own replication.
    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    // Build source coordinates from dst placeholders: the reduced axis is
    // driven by the replication placeholder, remaining axes shift down by
    // one past `dim`.
    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);
    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check if computed layout is compatible with existing: the existing one
      // must strictly contains the computed layout
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (int i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      // Replicate extents must be constants to compare them below.
      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      // The computed layout carries strictly more replication than the
      // existing one: propose it as the new destination layout.
      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
  }
  return {};
}

// Register the `tl.reduce` intrinsic; marked opaque so generic TIR passes
// do not rewrite calls to it.
// NOTE(review): the ReduceOp constructor consumes 5 args (src, dst, type,
// dim, clear) while num_inputs is declared as 4 — confirm whether
// num_inputs intentionally excludes the trailing flag.
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

376
377
378
379
380
381
382
383
384
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    CumSum arguments:
      src: input buffer
      dst: output buffer
      dim: dimension to cumsum
      reverse: whether to cumsum in reverse order
   */
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
  node->src = vmap[GetVarFromAccessPtr(args[0])];
  node->dst = vmap[GetVarFromAccessPtr(args[1])];
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();
  // Validate both bounds of the axis: the original check only rejected
  // dim >= rank, letting a negative dim slip through to Lower().
  CHECK_GE(node->dim, 0);
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()));
  data_ = std::move(node);
}

394
// Lower cumulative-sum to an extern tl::CumSum2D call. Only shared-memory
// source/destination buffers are supported; fragment cumsum is rejected.
Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
    auto threads = T.thread_bounds->extent;
    // Template parameters: thread count, axis, and reverse flag.
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    // access_ptr masks: 1 = read (src), 3 = read|write (dst).
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    // Pass the source shape so the kernel knows the extents. Use size_t to
    // avoid the signed/unsigned comparison of the original `int` counter.
    for (size_t i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }
  // Unreachable: LOG(FATAL)/ICHECK above abort, but keep a return for form.
  return Stmt();
}

// CumSum is lowered to an extern call on shared memory (see Lower above),
// so it contributes no fragment-layout constraints: always returns an
// empty map regardless of the inference level.
LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  return {};
}

// Register the `tl.cumsum` intrinsic (4 args: src, dst, dim, reverse);
// marked opaque so generic TIR passes do not rewrite calls to it.
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
429
430
} // namespace tl
} // namespace tvm