/*!
 * \file thread_partial_sync.cc
 * \brief Plan and insert (partial) thread synchronization barriers.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../op/builtin.h"
#include "./storage_access.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Plans where (partial) thread synchronization barriers must be
 *        inserted for accesses to buffers in the given sync scope.
 *
 * The planner scans each statement sequence, tracking reads and writes that
 * have not yet been separated by a sync, and records in `syncs_inserted_`
 * the statements that need a barrier in front of them.  Statements visited
 * inside a "kWarpSpecializationScope" partition additionally get an entry
 * in `partial_syncs_inserted_` carrying that partition's thread count, so
 * the inserter can emit a partial barrier instead of a full one.
 */
class TileLangThreadPartialSyncPlanner : public TileLangStorageAccessVisitor {
public:
  explicit TileLangThreadPartialSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}

  // The syncs inserted before each statement
  std::unordered_set<const Object *> syncs_inserted_;
  // Statements that need a *partial* sync, mapped to the number of
  // participating threads (taken from the warp-specialization partition).
  std::unordered_map<const Object *, int> partial_syncs_inserted_;

protected:
  // Only plan accesses made inside the device environment and belonging to
  // the storage scope this planner was constructed for.
  bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
    return in_device_env() && scope == sync_scope_;
  }
  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
                                     const ForNode *loop) final {
    // Redirect all "shared.dyn" buffer access to the same buffer var
    // so that the accesses can be planned together.
    Var shared_dyn_buf;
    for (StmtEntry &entry : seq) {
      for (AccessEntry &access : entry.access) {
        if (access.scope.rank == StorageRank::kShared &&
            access.scope.tag == ".dyn" && access.buffer.defined()) {
          if (!shared_dyn_buf.defined()) {
            shared_dyn_buf = access.buffer;
          } else {
            access.buffer = shared_dyn_buf;
          }
        }
      }
    }

    // Unsynced reads and writes
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // if it is a loop, rotate two times to consider effect of loop.
    // simulation based approach to find dependencies
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry &s = seq[i];
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // A sync is needed before this statement if it reads data written
      // since the last sync (RAW) or writes data read since the last sync
      // (WAR).  An explicit kSync entry resets the pending sets.
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      // If sync is inserted. remove the irrelevant things.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // Add the read/write of current statement
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      if (sync_before_stmt) {
        insert_syncs(s.stmt);
      }
    }
    // Second pass for loops: carry the tail's pending reads/writes back to
    // the head of the sequence to catch loop-carried dependencies; the
    // first statement that conflicts gets a sync and the scan stops.
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry &s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0)
          break;
        if (reads.empty() && writes.empty())
          break;
        bool sync_before_stmt = false;
        for (const AccessEntry &acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kSync) {
            reads.clear();
            writes.clear();
          }
        }
        if (sync_before_stmt) {
          insert_syncs(s.stmt);
          break;
        }
      }
    }
    // return the exposed entries, remove unnecessary ones.
    int sync_count = 0;
    // head are before first sync, tail are after last sync
    std::vector<AccessEntry> head, tail;
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry &s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(esync);
        }
        ++sync_count;
      }
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(esync);
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry &e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }

private:
  // find conflicting entry in vec.
  bool FindConflict(const std::vector<AccessEntry> &prev,
                    const AccessEntry &curr, bool loop_carry) {
    for (const AccessEntry &x : prev) {
      if (FindConflict(x, curr, loop_carry)) {
        return true;
      }
    }
    return false;
  }

  // Whether the pair of accesses (prev happening before curr) conflicts and
  // therefore requires a barrier between them.
  bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                    bool loop_carry) {
    // Access to different buffers does not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;
    }

    // Assumes no race between threads
    // Same index value means no conflicts
    // TODO(tqchen) more standard set based testing.
    bool has_same_index = true;
    // Even if access has the same index, those indices need to
    // depend on the innermost thread id to avoid race condition
    bool depends_on_thread_index = true;
    const VarNode *thread_index_var = nullptr;
    if (!curr.threads.empty()) {
      thread_index_var = curr.threads.back()->var.get();
    }

    for (size_t i = 0; i < prev.touched.size(); i++) {
      const auto &prev_intset = prev.touched[i];
      const auto &curr_intset = curr.touched[i];

      // Only single-point accesses can be proven identical; any range
      // access is conservatively treated as a differing index.
      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
        PrimExpr prev_index = prev_intset.PointValue();
        PrimExpr curr_index = curr_intset.PointValue();
        has_same_index = ExprDeepEqual()(prev_index, curr_index);
        if (thread_index_var != nullptr) {
          auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
            return parameter == thread_index_var;
          };
          depends_on_thread_index = depends_on_thread_index &&
                                    UsesVar(curr_index, f_uses_thread_index) &&
                                    UsesVar(prev_index, f_uses_thread_index);
        }
      } else {
        has_same_index = false;
      }

      if (!(has_same_index && depends_on_thread_index)) {
        break;
      }
    }
    if (has_same_index && depends_on_thread_index) {
      return false;
    }

    // If this is a read into a double buffer that was previously
    // swapped out, then it doesn't conflict.
    if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
      return false;
    }

    // If nothing else allows sharing the same buffer, then they are
    // in conflict.
    return true;
  }

  // Handle warp-specialization scopes: the then/else branches of the inner
  // IfThenElse run on two thread partitions (sizes given by op->node), so
  // each branch is summarized with `num_partial_threads_` set to its
  // partition size; any sync planned inside becomes a *partial* sync.
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "kWarpSpecializationScope") {
      IfThenElse body = Downcast<IfThenElse>(op->body);
      auto partitions = Downcast<Array<IntImm>>(op->node);
      ICHECK(partitions.size() == 2);

      scope_.push_back(std::vector<StmtEntry>());
      num_partial_threads_ = partitions[0];
      this->VisitStmt(body->then_case);
      StmtEntry s;
      s.stmt = op;
      s.access = Summarize(std::move(scope_.back()), nullptr);
      scope_.pop_back();

      num_partial_threads_ = partitions[1];
      scope_.push_back(std::vector<StmtEntry>());
      VisitStmt(body->else_case.value());
      auto v = Summarize(std::move(scope_.back()), nullptr);
      scope_.pop_back();
      s.access.insert(s.access.end(), v.begin(), v.end());
      // NOTE(review): `s` (the combined access summary of both partitions)
      // is built but never pushed back into the enclosing scope_, so the
      // surrounding scope does not observe this region's accesses — confirm
      // whether that is intentional.

      num_partial_threads_ = NullOpt;
    } else {
      TileLangStorageAccessVisitor::VisitStmt_(op);
    }
  }

  // Record that a sync is needed before `obj`.  When currently inside a
  // warp-specialization partition (`num_partial_threads_` defined), also
  // record the partition size so a partial barrier is emitted instead of a
  // full one.
  void insert_syncs(const Object *obj) {
    // ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside
    // condition";
    if (syncs_inserted_.count(obj))
      return;
    if (num_partial_threads_.defined()) {
      syncs_inserted_.insert(obj);
      partial_syncs_inserted_[obj] =
          static_cast<int>(num_partial_threads_.value()->value);
    } else {
      syncs_inserted_.insert(obj);
    }
  }

private:
  // Thread count of the warp-specialization partition currently being
  // visited; NullOpt outside of any partition.
  Optional<IntImm> num_partial_threads_;
  // synchronization scope
  StorageScope sync_scope_;
};

// There are cases where a necessary syncthreads is not inserted by
// ThreadPartialSyncInserter. For example, syncthreads is needed after
// async_wait_queue in the second loop below, but since
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
//    async_commit_queue(0):
//       async_scope:
//          shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline Epilogue
// for i in range(3):
//    async_wait_queue(0, 2 - i):
//       local[...] = shared[(i + 125) % 4]

class ThreadPartialSyncInserter : public StmtExprMutator {
317
318
319
320
public:
  ThreadPartialSyncInserter(
      StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
      std::unordered_map<const Object *, int> partial_syncs)
321
322
      : sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}

323
324
325
  Stmt VisitStmt(const Stmt &stmt) final {
    if (syncs_.size() == 0)
      return stmt;
326
327
328
329
330
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (partial_syncs_.count(stmt.get())) {
        auto iter = partial_syncs_.find(stmt.get());
        ICHECK(sync_scope_.rank == StorageRank::kShared);
331
332
        barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(),
                                {iter->second}));
333
334
335
336
337
338
339
340
341
342
343
344
      } else {
        return StmtExprMutator::VisitStmt(stmt);
      }
      // Mutate after query, to avoid stmt change.
      auto ret = StmtExprMutator::VisitStmt(stmt);
      ret = SeqStmt({barrier, ret});
      return ret;
    } else {
      return StmtExprMutator::VisitStmt(stmt);
    }
  }

345
private:
346
347
  // data structure.
  StorageScope sync_scope_;
348
349
  const std::unordered_set<const Object *> &syncs_;
  const std::unordered_map<const Object *, int> &partial_syncs_;
350
351
};

/*!
 * \brief Plan and insert (partial) thread synchronization barriers.
 * \param stmt The statement to transform.
 * \param storage_scope Name of the storage scope to synchronize.
 * \return The transformed statement with barriers inserted.
 */
Stmt TileLangThreadPartialSync(Stmt stmt, std::string storage_scope) {
  StorageScope scope = StorageScope::Create(storage_scope);
  // First pass: decide which statements need a (partial) barrier.
  TileLangThreadPartialSyncPlanner planner(scope);
  planner(stmt);
  // Second pass: materialize the planned barriers.  The inserter is used
  // within this full expression, so the planner's result sets stay alive.
  return ThreadPartialSyncInserter(scope, planner.syncs_inserted_,
                                   planner.partial_syncs_inserted_)(
      std::move(stmt));
}

using namespace tir::transform;

namespace transform {

/*!
 * \brief Create the PrimFunc pass that inserts partial thread syncs.
 * \param storage_scope Name of the storage scope to synchronize.
 * \return The created pass.
 */
Pass TileLangThreadPartialSync(String storage_scope) {
  auto transform = [storage_scope](PrimFunc func, IRModule mod,
                                   PassContext ctx) {
    auto *node = func.CopyOnWrite();
    node->body =
        tl::TileLangThreadPartialSync(std::move(node->body), storage_scope);
    return func;
  };
  return CreatePrimFuncPass(transform, 0, "tl.ThreadPartialSync", {});
}

// Expose the pass to the FFI layer under "tl.transform.ThreadPartialSync".
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
    .set_body_typed(TileLangThreadPartialSync);
} // namespace transform
} // namespace tl
} // namespace tvm