reduce.cc 7.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/*!
 * \file tl/op/reduce.cc
 *
 * Define reduce operator.
 */

#include "reduce.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/utils.h"
#include "../transform/loop_partition.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Build a ReduceOp from the arguments of a reduce intrinsic call.
 *
 * Argument layout, as read below:
 *   args[0] - access pointer of the source buffer (resolved through vmap)
 *   args[1] - access pointer of the destination buffer (resolved through vmap)
 *   args[2] - StringImm naming the reduction: "sum", "abssum", "absmax",
 *             "max" or "min"
 *   args[3] - IntImm giving the source dimension to reduce over
 *   args[4] - Bool; when true, Lower() emits a store of the initial value
 *             into dst before accumulating
 *
 * \param args Call arguments of the intrinsic.
 * \param vmap Mapping from buffer data variables to buffers.
 *
 * An unrecognized reduction name triggers an ICHECK failure.
 */
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  String reduce_type = args[2].as<StringImm>().value()->value;
  dim = args[3].as<IntImm>().value()->value;
  if (reduce_type == "sum")
    type = ReduceType::kSum;
  else if (reduce_type == "abssum")
    type = ReduceType::kAbsSum;
  else if (reduce_type == "absmax")
    type = ReduceType::kAbsMax;
  else if (reduce_type == "max")
    type = ReduceType::kMax;
  else if (reduce_type == "min")
    type = ReduceType::kMin;
  else
    ICHECK(0) << "Unknown reduce type: " << reduce_type;
  clear = args[4].as<Bool>().value();
}

/*!
 * \brief Identity element used to initialise dst before accumulation.
 *
 * \return A constant of dst's dtype:
 *   - sum / abssum: zero
 *   - absmax:       zero (absolute values are never negative)
 *   - max:          the lowest representable value for scalar integer dtypes,
 *                   negative infinity otherwise
 *   - min:          the highest representable value for scalar integer dtypes,
 *                   positive infinity otherwise
 *
 * Integer dtypes cannot represent infinity, and converting an infinite double
 * to an integer is undefined, so scalar integer dtypes take their bounds from
 * min_value()/max_value(). Those helpers are defined for scalar dtypes, so
 * any other dtype keeps the infinity-based constant.
 */
PrimExpr ReduceOp::MakeInitValue() const {
  const DataType dtype = dst->dtype;
  const bool scalar_integer =
      dtype.is_scalar() && (dtype.is_int() || dtype.is_uint());
  switch (type) {
  case ReduceType::kSum:
  case ReduceType::kAbsSum:
    return make_zero(dtype);
  case ReduceType::kMax:
    if (scalar_integer)
      return min_value(dtype);
    return make_const(dtype, -INFINITY);
  case ReduceType::kMin:
    if (scalar_integer)
      return max_value(dtype);
    return make_const(dtype, INFINITY);
  case ReduceType::kAbsMax:
    return make_const(dtype, 0);
  default:
    ICHECK(0);
    // Not reached when ICHECK aborts; keeps every control path returning a
    // value, matching MakeReduce() and MakeCodegenReducer().
    return PrimExpr(0);
  }
}

58
/*!
 * \brief Build the expression that folds one more value into the accumulator.
 *
 * \param a Current accumulator value.
 * \param b New value to fold in; cast to a's dtype when the dtypes differ.
 * \return The combined expression for this op's reduce type:
 *   - sum:    a + b
 *   - abssum: a + max(b, -b)
 *   - max:    max(a, b)
 *   - min:    min(a, b)
 *   - absmax: max(max(a, b), -min(a, b))
 */
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    // max(x, -x) is |x|.
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  case ReduceType::kAbsMax:
    // The largest magnitude of the pair is either the larger value or the
    // negation of the smaller one.
    return Max(Max(lhs, rhs), -Min(lhs, rhs));
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}

std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
82
83
84
85
86
87
88
89
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
90
91
  case ReduceType::kAbsMax:
    return "tl::MaxOp";
92
93
94
  default:
    ICHECK(0);
    return "";
95
96
97
  }
}

98
99
100
/*!
 * \brief Lower the reduce op into a loop nest over the destination.
 *
 * The generated body, in order:
 *   1. (only when `clear` is set) store the initial value into dst;
 *   2. a nest of unrolled loops that folds src elements into dst with
 *      MakeReduce();
 *   3. for every term of the source thread expression that depends on the
 *      reduced variable with extent > 1, a call_extern to
 *      "tl::AllReduce<Reducer, threads, scale>::run" (or "run_hopper" when
 *      the target "arch" attribute is "sm_90") whose result is stored back
 *      into dst.
 * The body is wrapped in parallel loops over the destination's input
 * dimensions and then passed to PartitionLoop() with the thread variable and
 * the destination layout.
 *
 * Both buffers must live in the "local.fragment" scope.
 *
 * \param T Lowering context (buffer remap, layout map, target, thread
 *          variable, block size, workspace allocator).
 * \param analyzer Arithmetic analyzer used for iterator simplification.
 * \return The lowered statement.
 */
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
  Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
  Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
  // The source has exactly one more input dimension: the reduced one.
  ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1);

  // Converts an iteration variable into the expression of its loop variable.
  auto to_var_expr = [](const auto &iv) { return PrimExpr(iv->var); };

  // One data-parallel iteration variable per destination dimension, named
  // "i", "j", "k", ... in order.
  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  // The source variables are the destination ones with the reduce variable
  // "rv" inserted at position `dim`.
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  IterVar(Range(0, src_layout->InputShape()[this->dim]),
                          Var("rv"), IterVarType::kDataPar));
  Array<PrimExpr> src_indices = src_layout->Forward(src_vars.Map(to_var_expr));
  Array<PrimExpr> dst_indices = dst_layout->Forward(dst_vars.Map(to_var_expr));

  Array<Stmt> stmts;

  // make reduce-init stmt
  if (this->clear)
    stmts.push_back(
        BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
  Array<IterVar> src_var_compressed;
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      dst_buffer,
      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  // Wrap innermost-first so that output dimension 0 ends up outermost.
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread =
      src_layout->ForwardThread(src_vars.Map(to_var_expr), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    // Only terms driven by the reduce variable need a cross-thread reduction.
    if (!mark.value().same_as(src_vars[this->dim]->var))
      continue;
    auto scale = as_const_int(iter_split->scale);
    auto extent = as_const_int(iter_split->extent);
    ICHECK(scale != nullptr && extent != nullptr);
    if (*extent == 1)
      continue;
    int reducing_threads = (*extent) * (*scale);

    bool has_arch = T.target->attrs.count("arch") > 0;
    bool use_hopper =
        has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90";
    std::stringstream ss;
    ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
       << reducing_threads << ", " << (*scale)
       << (use_hopper ? ">::run_hopper" : ">::run");

    Array<PrimExpr> thread_reduce_args = {
        StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
    // NOTE(review): 32 is presumably the warp size, i.e. reductions spanning
    // more than one warp need a scratch workspace - confirm against the
    // tl::AllReduce implementation.
    if (reducing_threads >= 32) {
      PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
      thread_reduce_args.push_back(workspace);
    }
    auto call =
        Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
    stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
  }

  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }

  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  return body;
}

203
204
205
/*!
 * \brief Derive the destination fragment layout from the source layout.
 *
 * Inference only happens below InferLevel::kStrict, when both buffers are in
 * the "local.fragment" scope, the source layout is already known and the
 * destination layout is not. In every other case an empty map is returned.
 *
 * The destination's replication extent is the source's replication extent
 * multiplied by the extent of the reduced dimension. The replication
 * placeholder is split so that its remainder modulo the reduced extent stands
 * in for the reduced index and its quotient is passed on as the source
 * replication index.
 *
 * \param T Layout inference context holding the layouts known so far.
 * \param level Inference level of the current pass.
 * \return A map containing the inferred layout of dst, or an empty map.
 */
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();

    PrimExpr indice_rep_extent = src->shape[dim];
    PrimExpr src_rep_extent = src_layout->ReplicateExtent();
    PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;

    // Source indices expressed through the destination's input placeholders:
    // dimensions after `dim` shift down by one, and `dim` itself is taken
    // from the replication placeholder.
    Array<PrimExpr> fwd;
    for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
      if (i == dim) {
        fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
      } else if (i < dim) {
        fwd.push_back(InputPlaceholder(i));
      } else {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
            ->CondenseReplicateVar();
    return {{dst, dst_layout}};
  }
  return {};
}

// Register the "reduce" op, marked opaque in its call effect kind.
// NOTE(review): ReduceOp's constructor reads five arguments (args[0] through
// args[4]) while four inputs are registered here - confirm whether
// set_num_inputs should be 5.
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
238

239
240
} // namespace tl
} // namespace tvm