/*!
 * \file tl/op/atomic_add.cc
 *
 * Define element-wise operators.
 */

#include "./atomic_add.h"
#include "./region.h"

#include <cstdlib>
#include <string>
#include <vector>

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  if (arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_') {
    arch_int = atoi(&arch_str[3]);
  } else {
    arch_int = 0;
  }
  return arch_int;
}

37
38
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
39
40
41
42
43
44
45
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
46
47
    rgs[i] = region->GetRanges();
    bf[i] = region->GetBuffer();
48
  }
49
50
  std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
  std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
51
  if (args.size() >= 3) {
52
    node->coalesced_width = Downcast<IntImm>(args[2]);
53
  }
54
  data_ = std::move(node);
55
56
}

57
58
59
60
61
62
63
64
65
TileOperator AtomicAddNode::Clone() const {
  auto op = make_object<AtomicAddNode>(*this);
  if (par_op_.defined()) {
    op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
  }
  return AtomicAdd(op);
}

/*!
 * \brief Create one data-parallel iteration variable per non-unit source
 *        dimension.
 *
 * Dimensions of src_range whose extent is the constant one are skipped. The
 * remaining dimensions get variables named "i", "j", "k", ... in order, each
 * with domain [0, extent) and the dtype of the extent.
 *
 * NOTE(review): names are produced by adding the running count to 'i', so
 * they leave the alphabet after enough dimensions — confirm this cannot
 * occur in practice.
 *
 * \return The iteration variables, outermost dimension first.
 */
Array<IterVar> AtomicAddNode::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
81
82
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                           int src_dst) const {
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size())
      << "idx = " << idx << ", ivs.size() = " << ivs.size()
      << "src name = " << src->name << ", dst name = " << dst->name;
  return indices;
}

100
101
102
103
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                      const Array<IterVar> &ivs,
                                      Array<PrimExpr> extents,
                                      int src_dst) const {
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}

131
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
  Array<IterVar> loop_vars = MakeIterVars();
  bool is_scalar = loop_vars.size() == 0;
  if (is_scalar) {
    return For(Var("i"), 0, 1, ForKind::kSerial,
               BufferStore(dst, BufferLoad(src, {0}), {0}));
  }

  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  ICHECK(loop_vars.size() <= src_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", src_range.size() = " << src_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  ICHECK(loop_vars.size() <= dst_range.size())
      << "loop_vars.size() = " << loop_vars.size()
      << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
      << ", dst = " << dst->name;

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  Array<PrimExpr> new_args;
  new_args.push_back(StringImm("AtomicAdd"));

  PrimExpr src_value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    src_value = Cast(dst->dtype, src_value);
  if (src_predicate.defined())
    src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));

  PrimExpr dst_value = BufferLoad(dst, dst_indices);
  if (dst_predicate.defined())
    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));

  Call address_of_value =
      tvm::tir::Call(DataType::Handle(), builtin::address_of(), {dst_value});

  new_args.push_back(address_of_value);
  new_args.push_back(src_value);

  Call atomicadd_call =
      tvm::tir::Call(dst->dtype, builtin::call_extern(), new_args);

  Stmt body = tvm::tir::Evaluate(atomicadd_call);

  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }

    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, std::nullopt, annotations);
  }
  return Downcast<For>(body);
}

194
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
195
196
197
  Target target = T.target;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
198
  auto par_op = ParallelOp(fused_loop);
199

200
201
202
  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
                                    InferLevel::kFree};
  for (auto level : levels) {
203
    (par_op)->InferLayout(
204
        {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
205
  }
206
207
208
209
210
211
212
213
214
  auto loop_layout = par_op->GetLoopLayout();
  Var thread_var = T.thread_var;
  Range thread_bounds = T.thread_bounds;
  auto thread_loop =
      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
  // TODO(@dyq): buggy implementation, need to fix
  // vectorized_thread_loop = VectorizeAtomicAdd(
  //     thread_loop, thread_var, thread_bounds, GetArchInt(target));
  auto vectorized_thread_loop = VectorizeLoop(thread_loop);
215
216
217
218
219
220
221
222
223

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}

224
225
226
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
  if (!par_op_.defined()) {
227
    arith::Analyzer analyzer;
228
    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
  }
  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
      if (src_layout && dst_layout) {
        ICHECK(src_layout->IsEqual(dst_layout, true))
            << "Get different layout for " << src << " and " << dst
            << "\nLHS = " << src_layout->DebugOutput()
            << "\nRHS = " << dst_layout->DebugOutput()
            << "\nYou may need to use a shared memory to transform the layout";
      }
    }
  }
  return par_op_->InferLayout(T, level);
}

// Register the AtomicAdd operator under the name `atomicadd`, declared with
// two inputs and an opaque call effect.
// NOTE(review): the AtomicAdd constructor also reads an optional third
// argument (the coalesced width) — confirm set_num_inputs(2) is consistent
// with calls that pass it.
TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm