/*!
 * \file tl/op/reduce.cc
 *
 * Define reduce operator.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/utils.h"
#include "../transform/loop_partition.h"
15
#include "tir/transforms/ir_utils.h"
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Build a ReduceOp from the arguments of a `reduce` call.
 *
 * Layout of \p args, as read below:
 *   - args[0]: access pointer identifying the source buffer
 *   - args[1]: access pointer identifying the destination buffer
 *   - args[2]: StringImm naming the reduction; one of "sum", "abssum",
 *              "absmax", "max", "min"
 *   - args[3]: IntImm stored in `dim`
 *   - args[4]: Bool stored in `clear`
 *
 * \param args Call arguments in the order listed above.
 * \param vmap Lookup table indexed by the variable extracted from each access
 *             pointer; yields the corresponding buffer.
 */
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  // All five slots are dereferenced below; fail early with a readable message.
  ICHECK_GE(args.size(), 5) << "reduce expects 5 arguments, but got "
                            << args.size();
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  String reduce_type = args[2].as<StringImm>().value()->value;
  dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    type = ReduceType::kMax;
  else if (reduce_type == "min")
    type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  clear = args[4].as<Bool>().value();
}

/*!
 * \brief Identity element used to initialise the accumulator.
 *
 * The value is expressed in the destination buffer's dtype:
 *   - kSum / kAbsSum: zero
 *   - kAbsMax:        zero
 *   - kMax:           -infinity, or the dtype's minimum for integer dtypes
 *   - kMin:           +infinity, or the dtype's maximum for integer dtypes
 *
 * \return A constant PrimExpr of dtype `dst->dtype`.
 */
PrimExpr ReduceOp::MakeInitValue() const {
  // Integer dtypes cannot represent infinity; converting +/-INFINITY to an
  // integer constant is undefined, so use the dtype's own bounds instead.
  const bool is_integer = dst->dtype.is_int() || dst->dtype.is_uint();
  switch (type) {
  case ReduceType::kSum:
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    if (is_integer)
      return min_value(dst->dtype);
    return make_const(dst->dtype, -INFINITY);
  case ReduceType::kMin:
    if (is_integer)
      return max_value(dst->dtype);
    return make_const(dst->dtype, INFINITY);
  case ReduceType::kAbsMax:
    return make_const(dst->dtype, 0);
  default:
    ICHECK(0);
    // Not reached; keeps every control path returning a value, matching the
    // sibling Make* helpers.
    return PrimExpr(0);
  }
}

59
/*!
 * \brief Combine an accumulator with one incoming element.
 *
 * \param a Current accumulator value; its dtype is the result dtype.
 * \param b Incoming element; cast to the dtype of \p a when the two differ.
 * \return Expression for the updated accumulator:
 *         kSum -> a + b, kAbsSum -> a + |b|, kMax -> max(a, b),
 *         kMin -> min(a, b), kAbsMax -> max(max(a, b), -min(a, b)).
 */
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a;
  PrimExpr rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    // |rhs| written as max(rhs, -rhs).
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  case ReduceType::kAbsMax:
    // The largest magnitude is either the larger operand or the negated
    // smaller one.
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
83
84
85
86
87
88
89
90
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
91
92
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
93
94
95
  default:
    ICHECK(0);
    return "";
96
97
98
  }
}

99
100
101
/*!
 * \brief Lower the reduce op to TIR statements.
 *
 * Only fragment-to-fragment reduction is handled (both buffers must have scope
 * "local.fragment"). The emitted body consists of, in order:
 *   1. an optional store of the init value into the accumulator,
 *   2. a nest of unrolled loops folding source elements into the accumulator,
 *   3. one `tl::AllReduce<...>` extern call per thread-index component that is
 *      driven by the reduced dimension,
 *   4. when a temporary accumulator is used, an add of it into `dst`.
 * These are wrapped in parallel loops over the destination's input dimensions
 * and handed to PartitionLoop.
 *
 * For kSum / kAbsSum with `clear == false`, accumulation happens in a freshly
 * declared "<dst>_clear" buffer that is allocated around the body and added to
 * `dst` at the end.
 *
 * \param T        Lowering context (buffer remap, layout map, target, thread
 *                 bounds, thread var, workspace callback).
 * \param analyzer Arithmetic analyzer used for iterator compression and
 *                 iter-sum normalisation.
 * \return The lowered statement.
 */
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  // The source has exactly one more input dimension: the reduced one.
  ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1);

  // One data-parallel iteration variable per destination dimension, named
  // 'i', 'j', 'k', ...
  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  // Source variables are the destination ones with the reduce variable "rv"
  // spliced in at position `dim`.
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // Sum-like reductions always start from the init value, irrespective of the
  // `clear` flag.
  const bool is_sum_like =
      this->type == ReduceType::kSum || this->type == ReduceType::kAbsSum;
  const bool require_init = this->clear || is_sum_like;

  // A sum-like reduction that must keep the previous contents of dst
  // accumulates into a separate buffer which is added to dst at the end.
  const bool need_duplicate = is_sum_like && !this->clear;
  Buffer clear_buffer = dst_buffer;
  if (need_duplicate) {
    // Create a new buffer with same shape and dtype as dst_buffer
    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                               dst_buffer->name + "_clear",
                               GetPtrStorageScope(dst_buffer->data));
  }

  // make reduce-init stmt
  if (require_init)
    stmts.push_back(
        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      clear_buffer,
      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  // Wrap the update in one unrolled loop per compressed iterator, innermost
  // dimension first.
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    // Only components of the thread index that come from the reduce variable
    // need a cross-thread reduction.
    if (!mark.value().same_as(src_vars[this->dim]->var))
      continue;

    auto scale = as_const_int(iter_split->scale);
    auto extent = as_const_int(iter_split->extent);
    ICHECK(scale != nullptr && extent != nullptr);
    if (*extent == 1)
      continue;

    int reducing_threads = (*extent) * (*scale);
    std::stringstream ss;
    ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
       << reducing_threads << ", " << (*scale) << ">::";
    // Targets whose "arch" attribute is "sm_90" use the run_hopper entry.
    bool has_arch = T.target->attrs.count("arch") > 0;
    if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
      ss << "run_hopper";
    } else {
      ss << "run";
    }

    Array<PrimExpr> thread_reduce_args = {
        StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
    if (reducing_threads >= 32) {
      // A workspace sized by the thread extent is passed as an extra argument.
      const int64_t *num_threads = as_const_int(T.thread_bounds->extent);
      ICHECK(num_threads != nullptr)
          << "Reduce requires a constant thread extent, but got "
          << T.thread_bounds->extent;
      PrimExpr workspace =
          T.AddWorkspace(*num_threads, clear_buffer->dtype);
      thread_reduce_args.push_back(workspace);
    }
    auto call =
        Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
    stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
  }

  // copy clear_buffer to dst_buffer
  if (need_duplicate) {
    // need_duplicate is only set for kSum / kAbsSum, and both fold the
    // temporary result into dst with an addition.
    stmts.push_back(BufferStore(dst_buffer,
                                Add(BufferLoad(dst_buffer, dst_indices),
                                    BufferLoad(clear_buffer, dst_indices)),
                                dst_indices));
  }

  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  if (need_duplicate) {
    body = Allocate(clear_buffer->data, clear_buffer->dtype,
                    clear_buffer->shape, const_true(), body);
  }
  return body;
}

249
250
251
/*!
 * \brief Derive the destination fragment layout from the source layout.
 *
 * Nothing is inferred when \p level is kStrict or higher. Otherwise, if both
 * buffers have scope "local.fragment", the source already has a layout in
 * `T.layout_map` and the destination does not, a destination Fragment is built
 * whose replication extent is `src->shape[dim] * src_layout->ReplicateExtent()`.
 *
 * \param T     Layout inference context; only `layout_map` is read here.
 * \param level Current inference level.
 * \return `{dst -> inferred layout}` on success, otherwise an empty map.
 */
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};

  const bool can_infer =
      src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst);
  if (!can_infer)
    return {};

  auto src_layout = T.layout_map[src].as<Fragment>().value();

  PrimExpr indice_rep_extent = src->shape[dim];
  PrimExpr src_rep_extent = src_layout->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

  // Source-space indices expressed through the destination's placeholders:
  // the reduced dimension is taken from the low part of the replication
  // placeholder, every other dimension from the matching destination input.
  Array<PrimExpr> fwd;
  for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
    if (i == dim) {
      fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
    } else if (i < dim) {
      fwd.push_back(InputPlaceholder(i));
    } else {
      // Dimensions after `dim` shift down by one in the destination.
      fwd.push_back(InputPlaceholder(i - 1));
    }
  }
  // The high part of the replication placeholder is forwarded as the source
  // layout's own replication index.
  auto thd = src_layout->ForwardThread(
      fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  Fragment dst_layout =
      Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
          ->CondenseReplicateVar();
  return {{dst, dst_layout}};
}

// Register ReduceOp under the name `reduce`. The ReduceOp constructor reads
// five call arguments (src, dst, reduce type, dim, clear), so the declared
// input count is five. The call effect is opaque.
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
284

285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
/*!
 * \brief Build a CumSumOp from the arguments of a `cumsum` call.
 *
 * \param args Exactly four entries:
 *             [0] access pointer of the input buffer,
 *             [1] access pointer of the output buffer,
 *             [2] IntImm: dimension to accumulate along,
 *             [3] Bool: whether to accumulate in reverse order.
 * \param vmap Lookup table indexed by the variable extracted from each access
 *             pointer; yields the corresponding buffer.
 */
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  CHECK_EQ(args.size(), 4);

  // Buffers.
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];

  // Scalar attributes.
  const IntImm dim_imm = args[2].as<IntImm>().value();
  dim = dim_imm->value;
  reverse = args[3].as<Bool>().value();

  // The requested dimension has to exist in the input.
  CHECK_LT(dim, static_cast<int>(src->shape.size()));
}

/*!
 * \brief Lower the cumsum op to an extern call.
 *
 * Supported case: source and destination both in "shared" / "shared.dyn"
 * scope. The result is an `Evaluate` of
 * `tl::CumSum2D<threads, dim, reverse>::run(src_ptr, dst_ptr, shape...)`.
 * Fragment-to-fragment cumsum raises a fatal error; any other scope
 * combination fails an ICHECK.
 *
 * \param T        Lowering context; only `thread_bounds` is read here.
 * \param analyzer Unused.
 * \return The lowered statement.
 */
Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    std::stringstream ss;
    // The number of participating threads is the extent of the thread range
    // itself (ReduceOp::Lower uses the same quantity); subtracting `min` would
    // under-count whenever the range does not start at zero.
    auto threads = T.thread_bounds->extent;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    // access_ptr masks: 1 for the source, 3 for the destination.
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    // Append every extent of the source shape as trailing arguments.
    for (const PrimExpr &extent : src->shape) {
      args.push_back(extent);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}

/*!
 * \brief CumSum contributes no layout constraints.
 *
 * \param T     Unused.
 * \param level Unused.
 * \return An empty layout map.
 */
LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return LayoutMap();
}

// Register CumSumOp under the name `cumsum` with four inputs (the count the
// CumSumOp constructor checks for) and an opaque call effect.
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
335
336
} // namespace tl
} // namespace tvm