/*!
 * \file tl/op/reduce.cc
 *
 * Define reduce and cumulative sum operators.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>

#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  String reduce_type = args[2].as<StringImm>().value()->value;
  dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    type = ReduceType::kMax;
  else if (reduce_type == "min")
    type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  clear = args[4].as<Bool>().value();
}

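// Identity element used to initialize the accumulator: zero for (abs)sum and
// absmax, the dtype's minimum for max, and the dtype's maximum for min
// (-inf / +inf for floating-point types).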
PrimExpr ReduceOp::MakeInitValue() const {
  auto dst_dtype = dst->dtype;
  bool is_int = dst_dtype.is_int();
  bool is_uint = dst_dtype.is_uint();

  switch (type) {
  case ReduceType::kSum:
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    if (is_int || is_uint) {
      // min_value() yields the smallest representable value of the integer
      // dtype (0 for unsigned types) without the signed-overflow risk of
      // shifting a 32-bit literal for 32-bit and wider dtypes.
      return min_value(dst_dtype);
    } else {
      return make_const(dst->dtype, -INFINITY);
    }
  case ReduceType::kMin:
    if (is_int || is_uint) {
      // max_value() yields the largest representable value of the integer
      // dtype, again avoiding overflow of a shifted 32-bit literal.
      return max_value(dst_dtype);
    } else {
      return make_const(dst->dtype, INFINITY);
    }
  case ReduceType::kAbsMax:
    return make_const(dst->dtype, 0);
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

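// Combine an accumulator value `a` with a new element `b` for the given
// reduction type; `b` is cast to `a`'s dtype if the dtypes differ.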
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  case ReduceType::kAbsMax:
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

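// Device-side reducer functor passed to tl::AllReduce. AbsSum and AbsMax map
// to SumOp/MaxOp because the absolute value is already applied in the
// thread-local combine step (MakeReduce).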
std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
  case ReduceType::kSum:
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
  default:
    ICHECK(0);
    return "";
  }
}

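// Lowering of a fragment-to-fragment reduction:
//   1. optionally write the reduction identity into the destination (or into a
//      scratch buffer for accumulating sums),
//   2. reduce the thread-local elements along the reduced axis with unrolled
//      loops,
//   3. combine partial results across threads via tl::AllReduce when the
//      reduced axis is spread over multiple threads,
//   4. for accumulating sums, add the partial result back into dst,
//   5. wrap everything in parallel spatial loops and partition them over the
//      launched threads.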
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  size_t src_dim = src_layout->InputDim();
  size_t dst_dim = dst_layout->InputDim();

  bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;

  if (is_1d_reduce) {
    ICHECK(is_one(dst_layout->OutputShape().back()))
        << "Reduce for scalar not implemented.";
  } else {
    ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
  }

  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_dim; i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars;
  if (!is_1d_reduce) {
    src_vars = dst_vars;
  }
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  bool require_init = this->clear;
  // sum-type reductions must always start from a cleared accumulator
  if (this->type == ReduceType::kSum || this->type == ReduceType::kAbsSum) {
    require_init = true;
  }

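  // An accumulating sum (clear == false) cannot be reduced directly into dst:
  // require_init would overwrite the existing contents with the identity
  // value. In that case the reduction runs on a separate zero-initialized
  // scratch buffer ("clear_buffer") whose result is added back into dst at
  // the end.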
  Buffer clear_buffer = dst_buffer;
  bool need_duplicate =
      !this->clear && (this->type == ReduceType::kSum ||
                       this->type == ReduceType::kAbsSum);

  if (need_duplicate) {
    // Create a new buffer with same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

  // make reduce-init stmt
  if (require_init)
    stmts.push_back(
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
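  // Compress each physical index with respect to the reduction variable so the
  // unrolled loops below only visit the distinct positions this thread holds
  // along the reduced axis.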
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, std::nullopt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
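  // Normalize the source layout's thread mapping into iterator splits and pick
  // out the split contributed by the reduction iterator: its extent * scale
  // gives the number of threads that hold pieces of the reduced axis and must
  // exchange partial results through tl::AllReduce.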
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;

      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;

      bool has_arch = T.target->attrs.count("arch") > 0;
      auto thread_offset = T.thread_bounds->min;
      if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
        auto all_threads = T.thread_bounds->extent;
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ", " << all_threads << ">::run_hopper";
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ", " << thread_offset
           << ">::run";
      }
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
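      // A shared-memory workspace (one element per launched thread) is added
      // once the reduction spans at least a full warp (32 threads),
      // presumably so AllReduce can stage partials across warps.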
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(
            *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread = BufferStore(
      clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);

  // copy clear_buffer to dst_buffer
  if (need_duplicate) {
    // For accumulating sum reductions, add the partial result that was
    // reduced into clear_buffer onto the existing contents of dst_buffer.
    if (this->type == ReduceType::kSum || this->type == ReduceType::kAbsSum) {
      stmts.push_back(BufferStore(dst_buffer,
                                  Add(BufferLoad(dst_buffer, dst_indices),
                                      BufferLoad(clear_buffer, dst_indices)),
                                  dst_indices));
    } else {
      ICHECK(false) << "Unsupported reduce type: " << (int)this->type;
    }
  }
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

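  // Distribute the parallel spatial loops over the launched threads according
  // to the destination fragment layout.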
  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
  return body;
}

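// Layout inference: derive the destination fragment layout from the source
// layout by folding the reduced dimension into the replication index, so every
// thread that held part of the reduced axis ends up with a replicated copy of
// the corresponding result element.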
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
            ->CondenseReplicateVar()
            ->BindThreadRange(T.thread_bounds);
    if (!T.layout_map.count(dst))
      return {{dst, dst_layout}};
    else {
      // Check whether the computed layout is compatible with the existing one:
      // the existing layout must strictly contain the computed layout.
      auto orig_dst_layout =
          T.layout_map.Get(dst).value().as<Fragment>().value();
      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
      Array<PrimExpr> indices;
      indices.reserve(dst_layout->InputDim());
      arith::Analyzer inner_analyzer;
      for (int i = 0; i < dst_layout->InputDim(); ++i) {
        auto x = InputPlaceholder(i);
        indices.push_back(x);
        // should be literal - literal = 0, any analyzer will work
        ICHECK(is_zero(inner_analyzer.Simplify(
            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
      }

      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
      ICHECK(as_const_int(src_layout->ReplicateExtent()));
      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
      if (dst_rep < src_rep ||
          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
                                 inner_analyzer)) {
        std::ostringstream oss;
        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
            << src << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << orig_dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the "
               "layout";
        throw LayoutConflictException(oss.str());
      }

      if (dst_rep > src_rep) {
        return {{dst, dst_layout}};
      }
    }
  }
  return {};
}

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    CumSum arguments:
      src: input buffer
      dst: output buffer
      dim: dimension to cumsum
      reverse: whether to cumsum in reverse order
   */
  CHECK_EQ(args.size(), 4);
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  dim = args[2].as<IntImm>().value()->value;
  reverse = args[3].as<Bool>().value();
  CHECK_LT(dim, static_cast<int>(src->shape.size()));
}

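// CumSum is currently lowered only for shared-memory operands, as a single
// extern call to the device-side tl::CumSum2D template; fragment operands are
// rejected with an explicit error.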
Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
    auto threads = T.thread_bounds->extent;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
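    // access_ptr masks: 1 = read for the source, 3 = read/write for the
    // destination; the trailing arguments pass the source shape.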
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    for (size_t i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}

LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm