/*!
 * \file op/parallel.cc
 * \brief Define the Parallel operator.
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Mark the coalesced width used when vectorizing the loop. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

// ProveFragmentContains checks whether the threads that access elements of a
// smaller fragment (small_frag) are a subset of the threads that access
// elements of a larger fragment (large_frag) for any given loop index. This
// guarantees that when the small fragment's layout corresponds to the loop
// itself, accessing the large fragment's elements is valid, and that the
// access remains valid if small_frag is later updated to large_frag. The
// proof is performed by:
//
// 1. Defining a variable `rep_small` to represent the replicate index of the
//    small fragment that is being checked.
// 2. Using `small_frag_indices` and `rep_small` to derive the thread that
//    accesses the element in the small fragment.
// 3. Using `large_frag_indices` to derive the physical index of the large
//    fragment along with the thread information, then feeding these into the
//    inverse of the large fragment to obtain the logical index and replicate
//    index.
// 4. Verifying the mapping by checking whether the thread recomputed through
//    the inverse layout matches the thread originally calculated for the
//    small fragment. If they don't match, the inverse layout's domain does not
//    include that thread and the access is invalid.
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_) {
  Var rep_small("__checking_frag_contains_rep");
  analyzer_.Bind(rep_small,
                 Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
                       small_frag->ReplicateExtent()),
                 true); // Bind the replicate extent of small_frag.
  // Derive thread for small_frag.
  auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);

  // Get physical index and thread for large_frag.
  auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
  // Add small_frag's thread to the large fragment's thread info.
  large_frag_physical_and_thread.push_back(thread);
  // Get the inverse of the large fragment.
  auto inv_large_frag = large_frag->Inverse();
  // Compute logical index and replicate index using inverse layout.
  auto inv_large_frag_logical_and_rep =
      inv_large_frag->Forward(large_frag_physical_and_thread);

  // Extract replicate index from the result.
  auto inv_large_frag_rep =
      inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];

  // Calculate thread based on the logical index and replicate index.
  auto check_thread =
      large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);

  // Simplify the difference between the threads.
  auto diff = analyzer_.Simplify(thread - check_thread);
  // If the difference is zero, the threads match and the access is valid.
  return is_zero(diff);
}

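/**
 * @brief Rewrite buffer accesses in a loop nest according to a buffer remap.
 *
 * Every BufferLoad/BufferStore whose buffer appears in buffer_remap has its
 * indices transformed with the corresponding layout's Forward mapping and is
 * redirected to the remapped buffer. This is used below to obtain the
 * post-remap loop body before estimating the vectorization size.
 */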
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

/**
 * @brief Handle a For node during traversal, collecting loop metadata.
 *
 * Records a data-parallel IterVar for parallel loops and an ordered IterVar
 * (in inner_vars_) for non-parallel ones, binds the loop variable range into
 * the analyzer scope, and extracts any reducer information from the loop's
 * annotations into the visitor's reducer_info_map_. Continues traversal into
 * the loop body.
 */
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  if (op->kind == ForKind::kParallel)
    p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
                                    IterVarType::kDataPar));
  else
    p->inner_vars_.Set(op->loop_var,
                       IterVar(Range(op->min, op->extent), op->loop_var,
                               IterVarType::kOrdered));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  auto reducer_info_map =
      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
  if (reducer_info_map) {
    for (auto &&[buffer, info] : reducer_info_map.value())
      p->reducer_info_map_.Set(buffer, info);
  }
  StmtExprVisitor::VisitStmt_(op);
}

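/**
 * @brief Record fragment accesses on the store side.
 *
 * For stores into "local.fragment" buffers, the access indices are recorded
 * the first time the buffer is seen and later accesses are checked to be
 * structurally identical; the buffer is also marked as written.
 */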
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

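/**
 * @brief Record fragment accesses on the load side.
 *
 * Mirrors the store visitor: loads from "local.fragment" buffers must use
 * structurally identical indices across all accesses, but the buffer is not
 * marked as written.
 */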
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

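/**
 * @brief Construct a ParallelOpNode from a parallel loop nest.
 *
 * Stores the root loop and immediately visits it to collect loop variables,
 * fragment access indices and reducer information.
 */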
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
  V.VisitStmt(root);
}

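/**
 * @brief Create a copy of this operator wrapped in a new TileOperator handle.
 */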
TileOperator ParallelOpNode::Clone() const {
  auto op = make_object<ParallelOpNode>(*this);
  return ParallelOp(op);
}

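/**
 * @brief Lowering returns the stored loop nest unchanged; this operator does
 * not rewrite the loop body itself.
 */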
Stmt ParallelOpNode::Lower(const LowerArgs &T,
                           arith::Analyzer *analyzer) const {
  return root_;
}

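/**
 * @brief Check whether a buffer is indexed directly by the loop variables.
 *
 * Returns true when the recorded indices for `buffer` are exactly the
 * parallel loop variables, in order.
 */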
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

/*! \brief Infer the layout for parallel operations based on the inference
 * level.
 *
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only accepts explicitly defined
 *                layouts and never infers a new loop layout; used when exact
 *                layout control is required.
 *
 * - kCommon (1): Intermediate level between strict and free. Allows common
 *                layout patterns while maintaining some constraints.
 *
 * - kFree (0):   Most permissive level. Allows maximum optimization freedom:
 *                attempts layout inference even without source buffers and
 *                can generate new layouts based on vectorization and thread
 *                bounds. Used when maximum performance is desired.
 */
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                                      InferLevel level) const {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer the loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, indices] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      // skip reducers with rep=ALL
      if (auto info = reducer_info_map_.Get(buffer->data);
          info && info.value()->rep == ReducerRepType::ALL)
        continue;

      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer)) {
        source_buffer = buffer;
      } else {
        // Keep the read buffer with the largest number of indices as
        // read_source_buffer: more indices give the inference more
        // constraints, so the resulting layout is more accurate.
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
        // If the buffer's layout is not replicated and no write buffer has
        // been chosen yet, prefer it as source_buffer, since a non-replicated
        // layout gives a more reliable inference.
        if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
          source_buffer = buffer;
        }
      }
    }
  }
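  // Derive the loop partition from a fragment buffer's layout. If the loop
  // variables index the buffer directly, the buffer's layout is reused as the
  // loop layout; otherwise the thread mapping is obtained by composing the
  // buffer layout with the recorded access indices. Inner (non-parallel) loop
  // variables must not appear in the resulting thread mapping.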
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
               << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
    Fragment result;
    if (IsCommonAccessIndice(buffer)) {
      result = src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
        if (auto opt_var = objref.as<Var>();
            opt_var && inner_vars_.count(*opt_var)) {
          std::ostringstream oss;
          oss << "loop_var_to_thread = " << loop_var_to_thread
              << " contains inner var " << *opt_var;
          throw LayoutConflictException(oss.str());
        }
      });
      result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                   ->BindThreadRange(T.thread_bounds);
    }
    DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
               << result->DebugOutput() << '\n';
    return result;
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // // The loop doesn't need to be replicated.
      // if (!is_one(loop_layout_->ReplicateExtent()))
      //   loop_layout_ = loop_layout_->DeReplicate();

      // For free layout inference: if the loop layout is replicated and the
      // loop body accesses shared or global memory (visible across threads),
      // add a predicate so that only one replica performs the access, unless
      // the loop only performs direct buffer-to-buffer copies (checked below).
      bool has_cross_thread_access = false;
      PostOrderVisit(root_, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          // check if scope is shared or global
          if (store->buffer.scope() == "shared" ||
              store->buffer.scope() == "shared.dyn" ||
              store->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          // check if scope is shared or global
          if (load->buffer.scope() == "shared" ||
              load->buffer.scope() == "shared.dyn" ||
              load->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        }
      });

      // check if loop body contains a "pure" buffer store (i.e., direct
      // assignment, not compound update)
      bool has_pure_buffer_store = false;
      PostOrderVisit(root_, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          // Check if the value is a direct load from another buffer (i.e., b[i]
          // = a[i])
          if (store->value.as<BufferLoadNode>()) {
            has_pure_buffer_store = true;
          }
        }
      });

      if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
          !has_pure_buffer_store) {
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    } else {
      // The vectorize size must be computed on the remapped buffers, since
      // later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';

      PrimExpr loop_total_size = 1;
      for (Stmt l = root_; l.as<For>().has_value();
           l = l.as<For>().value()->body)
        loop_total_size = loop_total_size * l.as<For>().value()->extent;
      DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
                 << '\n';
      while (!analyzer_.CanProve(
                 floormod(loop_total_size,
                          T.thread_bounds->extent * vector_size) == 0) &&
             vector_size > 1)
        vector_size /= 2;
      DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
                 << vector_size << '\n';

      // Check if coalesced_width is defined
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width->as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
      DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
                 << " ############# vector_size = " << vector_size
                 << ", thread_bounds = " << T.thread_bounds << '\n';
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
      DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                 << loop_layout_->DebugOutput() << '\n';
    }
  } else {
    return {};
  }

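  // If the inferred loop layout covers only a sub-range of the block's
  // threads, or uses fewer threads than the block provides, predicate the
  // loop on the thread index so that the remaining threads skip the body.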
  PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

  auto block_size = T.thread_bounds->extent;
  if (loop_layout_.defined()) {
    if (loop_layout_->ThreadRange().defined()) {
      auto thread_range = loop_layout_->ThreadRange();
      block_size = thread_range->extent;
      AddPredicate(GE(InputPlaceholder(0), thread_range->min));
      AddPredicate(
          LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
    }
  }

  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
    AddPredicate(
        LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
  }

  // Step 2: check that the loop's partition aligns correctly with every
  // source fragment, and infer a layout only for buffers that do not yet
  // have one.
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      if (!ProveFragmentContains(loop_layout_, fragment, vars,
                                 indice_map_[buffer], analyzer_)) {
        std::ostringstream oss;
        oss << "Layout infer conflict between " << buffer << " and "
            << source_buffer << " in T.Parallel loop:" << '\n'
            << "    loop " << loop_layout_->DebugOutput() << '\n'
            << "    fragment " << fragment->DebugOutput() << '\n';
        throw LayoutConflictException(oss.str());
      }
    } else {
      auto dst_layout =
          CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
      results.Set(buffer, dst_layout);
    }
  }
  return results;
}

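/**
 * @brief Return the accumulated access predicate for a given thread variable.
 *
 * Substitutes the thread placeholder in the collected predicate with
 * `thread_var`, or returns std::nullopt when no predicate was generated.
 */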
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return std::nullopt;
  }
}

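/**
 * @brief Complete a fragment layout for `buffer` from the inferred loop
 * layout.
 *
 * Loop iterators that do not appear in the buffer's indices become an extra
 * replication dimension: the index mapping (extended with that replication
 * term) is inverted, composed with the loop layout's thread mapping, and the
 * resulting fragment's replicate variable is condensed.
 */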
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer)) {
    return loop_layout_;
  }
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
                  std::nullopt)
      ->CondenseReplicateVar();
}

TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); });

} // namespace tl
} // namespace tvm