/*!
 * \file tl/op/reduce.cc
 * \brief Define reduce and cumulative-sum (cumsum) operators.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/utils.h"
#include "../transform/loop_partition.h"

namespace tvm {
namespace tl {

using namespace tir;

ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  String reduce_type = args[2].as<StringImm>().value()->value;
  dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    type = ReduceType::kMax;
  else if (reduce_type == "min")
    type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  clear = args[4].as<Bool>().value();
}

// Return the identity element of the reduction, used to initialize the
// destination when `clear` is requested.
PrimExpr ReduceOp::MakeInitValue() const {
  switch (type) {
  case ReduceType::kSum:
    return make_zero(dst->dtype);
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    return make_const(dst->dtype, -INFINITY);
  case ReduceType::kMin:
    return make_const(dst->dtype, INFINITY);
  case ReduceType::kAbsMax:
    return make_const(dst->dtype, 0);
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

// Combine one source element into the running reduction value. The right
// operand is cast to the left operand's dtype when the dtypes differ.
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  case ReduceType::kAbsMax:
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

// Map the reduction kind to the device-side combiner used by tl::AllReduce.
// kAbsSum and kAbsMax reuse SumOp/MaxOp because the absolute value is already
// applied in the thread-local step (MakeReduce).
std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
  default:
    ICHECK(0);
    return "";
  }
}

// Lower a fragment-to-fragment reduction in three phases: optional
// initialization of the destination, a thread-local accumulation over the
// elements each thread owns, and an inter-thread AllReduce when the reduced
// axis spans multiple threads.
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1);
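  // `dst` drops the reduced dimension, so it has one fewer input dim than
  // `src`. Build data-parallel iteration vars over the destination shape and
  // insert an extra reduction var ("rv") at position `dim` for the source.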
  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // make reduce-init stmt
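  // When `clear` is set, seed the destination with the reduction identity;
  // otherwise the reduction accumulates onto the existing destination values.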
  if (this->clear)
    stmts.push_back(
        BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
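  // Each thread first folds the source elements it owns along the reduced
  // axis into `dst`. CompressIterator yields, for each physical source index,
  // a compressed index expression and a loop variable covering only the part
  // that varies with the reduction var; the loops below unroll over those
  // compressed extents.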
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      dst_buffer,
      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
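  // If the reduced axis also spans multiple threads, combine the partial
  // results with a tl::AllReduce extern call. The source thread mapping is
  // normalized to an iterator sum; every split term rooted at the reduction
  // var with extent > 1 defines a group of reducing threads. Groups of 32 or
  // more threads get an extra workspace argument, and sm_90 targets use the
  // run_hopper entry point.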
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;
      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;

      bool has_arch = T.target->attrs.count("arch") > 0;
      if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ">::run_hopper";
      } else {
        ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
           << reducing_threads << ", " << (*scale) << ">::run";
      }
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(
            *as_const_int(T.thread_bounds->extent), dst_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
    }
  }
  Stmt reduce_interthread =
      BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices);

  // make the outer spatial loop
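  // The destination dimensions become parallel loops, which PartitionLoop
  // then distributes over the thread block according to the destination
  // fragment layout.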
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  return body;
}

// Infer the destination fragment layout from the source layout; skipped at or
// above the strict inference level.
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst)) {
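    // Derive the destination fragment from the source fragment: the reduced
    // dimension is folded into the replication axis, so the reduced result is
    // replicated across the threads that held different positions along
    // `dim`. CondenseReplicateVar then removes redundant replication.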
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else if (i > dim) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
            ->CondenseReplicateVar();
    return {{dst, dst_layout}};
  }
  return {};
}

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    CumSum arguments:
      src: input buffer
      dst: output buffer
      dim: dimension along which to accumulate
      reverse: whether to accumulate in reverse order
   */
  CHECK_EQ(args.size(), 4);
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  dim = args[2].as<IntImm>().value()->value;
  reverse = args[3].as<Bool>().value();
  CHECK_LT(dim, static_cast<int>(src->shape.size()));
}

Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
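    // Lower to the tl::CumSum2D device template, parameterized by the thread
    // count derived from the thread bounds, the cumsum dimension, and the
    // traversal direction; the extern call receives the source/destination
    // pointers followed by the source shape.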
    std::stringstream ss;
    auto threads = T.thread_bounds->extent - T.thread_bounds->min;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    // Buffer::access_ptr masks: 1 = read (src), 3 = read | write (dst).
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    for (size_t i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }

  return Stmt();
}

LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm