/*!
 * \file thread_partial_sync.cc
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>

#include "../op/builtin.h"
#include "./storage_access.h"
#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

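// Plans where thread synchronization barriers are needed for a given storage
// scope. Built on TileLangStorageAccessVisitor: it tracks exposed reads and
// writes per statement and marks a statement when a conflicting access would
// otherwise race. Syncs planned inside a warp-specialized branch are recorded
// as partial syncs (thread count + barrier id).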
class TileLangThreadPartialSyncPlanner : public TileLangStorageAccessVisitor {
public:
  explicit TileLangThreadPartialSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}

  // The syncs inserted before each statement
  std::unordered_set<const Object *> syncs_inserted_;
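  // Statements that need a partial sync, mapped to
  // (number of participating threads, barrier id).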
  std::unordered_map<const Object *, std::tuple<int, int>>
      partial_syncs_inserted_;

protected:
  bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
    return in_device_env() && scope == sync_scope_;
  }
  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
                                     const ForNode *loop) final {
    // Redirect all "shared.dyn" buffer access to the same buffer var
    // so that the accesses can be planned together.
    Var shared_dyn_buf;
    for (StmtEntry &entry : seq) {
      for (AccessEntry &access : entry.access) {
        if (access.scope.rank == StorageRank::kShared &&
            access.scope.tag == ".dyn" && access.buffer.defined()) {
          if (!shared_dyn_buf.defined()) {
            shared_dyn_buf = access.buffer;
          } else {
            access.buffer = shared_dyn_buf;
          }
        }
      }
    }

    // Unsynced reads and writes
    std::vector<AccessEntry> reads;
    std::vector<AccessEntry> writes;
    // If this is a loop, the sequence is conceptually walked twice (see the
    // loop-carry pass below) to account for dependencies carried across
    // iterations; a simulation-based approach to finding dependencies.
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry &s = seq[i];
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kWrite) {
          if (FindConflict(reads, acc, false)) {
            sync_before_stmt = true;
            break;
          }
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      // If a sync is inserted, the previously tracked reads/writes are no
      // longer relevant.
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
      }
      // Add the read/write of current statement
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {
          writes.push_back(acc);
        } else if (acc.type == kSync) {
          reads.clear();
          writes.clear();
        }
      }
      if (sync_before_stmt) {
        insert_syncs(s.stmt);
      }
    }
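    // Second pass for loops: accesses still exposed at the end of the body
    // may conflict with accesses at the start of the next iteration, so check
    // for loop-carried conflicts and insert at most one extra sync.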
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry &s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0)
          break;
        if (reads.empty() && writes.empty())
          break;
        bool sync_before_stmt = false;
        for (const AccessEntry &acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kWrite) {
            if (FindConflict(reads, acc, true)) {
              sync_before_stmt = true;
              break;
            }
          } else if (acc.type == kSync) {
            reads.clear();
            writes.clear();
          }
        }
        if (sync_before_stmt) {
          insert_syncs(s.stmt);
          break;
        }
      }
    }
    // return the exposed entries, remove unnecessary ones.
    int sync_count = 0;
    // head holds accesses before the first sync, tail those after the last sync
    std::vector<AccessEntry> head, tail;
    AccessEntry esync;
    esync.threads = this->env_threads();
    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry &s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();
        } else {
          head.push_back(esync);
        }
        ++sync_count;
      }
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();
          } else {
            head.push_back(esync);
          }
          ++sync_count;
        } else {
          if (sync_count != 0) {
            tail.push_back(acc);
          } else {
            head.push_back(acc);
          }
        }
      }
    }
    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry &e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }

private:
  // Returns true if any entry in prev conflicts with curr.
  bool FindConflict(const std::vector<AccessEntry> &prev,
                    const AccessEntry &curr, bool loop_carry) {
    for (const AccessEntry &x : prev) {
      if (FindConflict(x, curr, loop_carry)) {
        return true;
      }
    }
    return false;
  }

  bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                    bool loop_carry) {
    // Access to different buffers does not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;
    }

    // Assumes no race between threads
    // Same index value means no conflicts
    // TODO(tqchen) more standard set based testing.
    bool has_same_index = true;
    // Even if the accesses share the same index, that index must depend on
    // the innermost thread id to avoid a race condition.
    bool depends_on_thread_index = true;
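    // The dependence test uses the innermost bound thread variable
    // (the last entry of curr.threads).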
    const VarNode *thread_index_var = nullptr;
    if (!curr.threads.empty()) {
      thread_index_var = curr.threads.back()->var.get();
    }

    for (size_t i = 0; i < prev.touched.size(); i++) {
      const auto &prev_intset = prev.touched[i];
      const auto &curr_intset = curr.touched[i];

      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
        PrimExpr prev_index = prev_intset.PointValue();
        PrimExpr curr_index = curr_intset.PointValue();
        has_same_index = ExprDeepEqual()(prev_index, curr_index);
        if (thread_index_var != nullptr) {
          auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
            return parameter == thread_index_var;
          };
          depends_on_thread_index = depends_on_thread_index &&
                                    UsesVar(curr_index, f_uses_thread_index) &&
                                    UsesVar(prev_index, f_uses_thread_index);
        }
      } else {
        has_same_index = false;
      }

      if (!(has_same_index && depends_on_thread_index)) {
        break;
      }
    }
    if (has_same_index && depends_on_thread_index) {
      return false;
    }

    // If this is a read from a double buffer that was previously
    // swapped out, it does not conflict.
    if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
      return false;
    }

    // If nothing else allows sharing the same buffer, then they are
    // in conflict.
    return true;
  }

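  // A "kWarpSpecializationScope" attribute wraps an IfThenElse whose two
  // branches run on disjoint thread partitions (op->node carries the two
  // thread counts). Each branch is summarized separately with its own thread
  // count and barrier id, so syncs planned inside it become partial syncs.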
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "kWarpSpecializationScope") {
      IfThenElse body = Downcast<IfThenElse>(op->body);
      auto partitions = Downcast<Array<IntImm>>(op->node);
      ICHECK(partitions.size() == 2);

      scope_.push_back(std::vector<StmtEntry>());
      num_partial_threads_ = partitions[0];
      barrier_id_ += 1;
      this->VisitStmt(body->then_case);
      StmtEntry s;
      s.stmt = op;
      s.access = Summarize(std::move(scope_.back()), nullptr);
      scope_.pop_back();
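      // Reclaim the barrier id if this branch did not insert any partial sync.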
      if (!has_sync_)
        barrier_id_ -= 1;
      has_sync_ = false;
      num_partial_threads_ = partitions[1];
      scope_.push_back(std::vector<StmtEntry>());
      barrier_id_ += 1;
      VisitStmt(body->else_case.value());
      auto v = Summarize(std::move(scope_.back()), nullptr);
      scope_.pop_back();
      if (!has_sync_)
        barrier_id_ -= 1;
      has_sync_ = false;
      s.access.insert(s.access.end(), v.begin(), v.end());

      num_partial_threads_ = std::nullopt;
    } else {
      TileLangStorageAccessVisitor::VisitStmt_(op);
    }
  }

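  // Record a sync before `obj`. Inside a warp-specialized branch
  // (num_partial_threads_ is set) and with a valid barrier id (0..15) the
  // sync is recorded as a partial sync; otherwise a plain sync is recorded.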
  void insert_syncs(const Object *obj) {
    // ICHECK_EQ(condition_counter(), 0)
    //     << "Cannot insert syncs inside condition";
    if (syncs_inserted_.count(obj))
      return;
    if (num_partial_threads_.defined() && barrier_id_ >= 0 &&
        barrier_id_ < 16) {
      syncs_inserted_.insert(obj);
      partial_syncs_inserted_[obj] = std::make_tuple(
          static_cast<int>(num_partial_threads_.value()->value), barrier_id_);
      has_sync_ = true;
    } else {
      syncs_inserted_.insert(obj);
    }
  }

private:
  Optional<IntImm> num_partial_threads_;
  // synchronization scope
  StorageScope sync_scope_;
  int barrier_id_{-1};
  bool has_sync_{false};
};

// There are cases where a necessary syncthreads is not inserted by
// ThreadPartialSyncInserter. For example, a syncthreads is needed after
// async_wait_queue in the second loop below, but since
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
//    async_commit_queue(0):
//       async_scope:
//          shared[(i + 3) % 4] = ...
// ...
//
// // Pipeline Epilogue
// for i in range(3):
//    async_wait_queue(0, 2 - i):
//       local[...] = shared[(i + 125) % 4]

class ThreadPartialSyncInserter : public StmtExprMutator {
public:
  ThreadPartialSyncInserter(
      StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
      std::unordered_map<const Object *, std::tuple<int, int>> partial_syncs)
      : sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}

  Stmt VisitStmt(const Stmt &stmt) final {
    if (syncs_.size() == 0)
      return stmt;
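    // Only statements that the planner marked with a partial sync get a
    // barrier materialized here; other marked statements are visited but
    // left unchanged.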
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (partial_syncs_.count(stmt.get())) {
        auto iter = partial_syncs_.find(stmt.get());
        ICHECK(sync_scope_.rank == StorageRank::kShared);
        int num_threads, barrier_id;
        std::tie(num_threads, barrier_id) = iter->second;
        barrier = Evaluate(Call(DataType::Int(32), tl::sync_thread_partial(),
                                {num_threads, barrier_id}));
      } else {
        return StmtExprMutator::VisitStmt(stmt);
      }
      // Mutate only after the lookup so the statement pointer used as the
      // key does not change.
      auto ret = StmtExprMutator::VisitStmt(stmt);
      ret = SeqStmt({barrier, ret});
      return ret;
    } else {
      return StmtExprMutator::VisitStmt(stmt);
    }
  }

private:
  // data structure.
  StorageScope sync_scope_;
  const std::unordered_set<const Object *> &syncs_;
  const std::unordered_map<const Object *, std::tuple<int, int>>
      &partial_syncs_;
};

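// Plans and inserts partial thread synchronization for the given storage
// scope: run the planner over `stmt`, then prepend a tl.sync_thread_partial
// call before every statement that needs a partial barrier.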
Stmt TileLangThreadPartialSync(Stmt stmt, std::string storage_scope) {
  StorageScope sync_scope = StorageScope::Create(storage_scope);
  TileLangThreadPartialSyncPlanner planner(sync_scope);
  planner(stmt);
  return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_,
                                   planner.partial_syncs_inserted_)(
      std::move(stmt));
}

using namespace tir::transform;

namespace transform {

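// PrimFunc-level pass wrapper; registered below as
// "tl.transform.ThreadPartialSync".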
Pass TileLangThreadPartialSync(String storage_scope) {
  auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
    auto *n = f.CopyOnWrite();
    n->body = tl::TileLangThreadPartialSync(std::move(n->body), storage_scope);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.ThreadPartialSync",
                        TileLangThreadPartialSync);
});

} // namespace transform
} // namespace tl
} // namespace tvm