parallel.cc 28.1 KB
Newer Older
1
2
3
4
5
6
7
/*!
 * \file op/parallel.cc
 * \brief Define Parallel for operator
 */

#include "parallel.h"

#include <algorithm>
#include <functional>
#include <sstream>
#include <unordered_set>
#include <vector>

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*!
 * \brief Loop annotation key carrying an explicit coalesced (vector) width.
 *
 * When the root loop carries this annotation with an IntImm value, free
 * layout inference uses that value as the vectorization width instead of the
 * inferred one (the inferred width must be divisible by it).
 */
static constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// ProveFragmentContains checks whether the threads that access elements of a
// smaller fragment (small_frag) are a subset of the threads that access
// elements of a larger fragment (large_frag) for any given loop index. This
// function ensures that if the small fragment's layout corresponds to the loop
// itself, accessing the large fragment's elements is valid. Additionally, if
// small is updated to large, the originally valid access remains valid. The
// proof is performed by:
//
// 1. Defining a variable `rep_small` to represent the replicate index of the
//    small fragment that is being checked.
// 2. Using the `small_frag_indices` and `rep_small` to derive the thread
// accessing
//    the element in the small fragment.
// 3. Using `large_frag_indices` to derive the physical index of the large
// fragment
//    along with the thread information, and then feeding these into the inverse
//    of the large fragment to obtain the logical index and replicate index.
// 4. Verifying the mapping by checking whether the computed thread using the
// inverse
//    layout corresponds to the original thread calculated for the small
//    fragment. If they don't match, this indicates that the inverse layout's
//    domain does not include the thread and thus the access is invalid.
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           arith::Analyzer &analyzer_) {
  Var rep_small("__checking_frag_contains_rep");
  analyzer_.Bind(rep_small,
                 Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
                       small_frag->ReplicateExtent()),
                 true); // Bind the replicate extent of small_frag.
  // Derive thread for small_frag.
  auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);

  // Get physical index and thread for large_frag.
  auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
  // Add small_frag's thread to the large fragment's thread info.
  large_frag_physical_and_thread.push_back(thread);
  // Get the inverse of the large fragment.
  auto inv_large_frag = large_frag->Inverse();
  // Compute logical index and replicate index using inverse layout.
  auto inv_large_frag_logical_and_rep =
      inv_large_frag->Forward(large_frag_physical_and_thread);

  // Extract replicate index from the result.
  auto inv_large_frag_rep =
      inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];

  // Calculate thread based on the logical index and replicate index.
  auto check_thread =
      large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);

  // Simplify the difference between the threads.
  auto diff = analyzer_.Simplify(thread - check_thread);
  // If the difference is zero, the threads match and the access is valid.
  return is_zero(diff);
}

84
class IfBufferRemapLoopGenerator : public StmtExprMutator {
85
public:
86
87
88
89
90
91
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

92
93
94
private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
95
96
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

97
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
98
99
100
101
102
103
104
105
106
107
108
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

109
  Stmt VisitStmt_(const BufferStoreNode *op) final {
110
111
112
113
114
115
116
117
118
119
120
121
122
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

123
124
125
126
127
128
129
130
/**
 * @brief Handle a parallel For node during traversal, collecting loop metadata.
 *
 * Visits a parallel loop, asserts the loop is parallel, records a data-parallel
 * IterVar for the loop, binds the loop variable range into the analyzer scope,
 * and extracts any reducer information from the loop's annotations into the
 * visitor's reducer_info_map_. Continues traversal into the loop body.
 */
131
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
132
133
134
135
136
137
138
  if (op->kind == ForKind::kParallel)
    p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
                                    IterVarType::kDataPar));
  else
    p->inner_vars_.Set(op->loop_var,
                       IterVar(Range(op->min, op->extent), op->loop_var,
                               IterVarType::kOrdered));
139
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
140
141
142
143
144
145
  auto reducer_info_map =
      op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
  if (reducer_info_map) {
    for (auto &&[buffer, info] : reducer_info_map.value())
      p->reducer_info_map_.Set(buffer, info);
  }
146
147
148
  StmtExprVisitor::VisitStmt_(op);
}

149
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
150
151
152
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
153
154
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
155
156
157
158
159
160
161
162
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

163
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
164
165
166
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
167
168
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
169
170
171
172
173
174
175
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

176
177
178
179
180
/*! \brief Store \p root and immediately walk it with the nested visitor V
 *         to collect loop variables and fragment accesses. */
ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
  V.VisitStmt(root_);
}

/*!
 * \brief Return a new ParallelOp wrapping a member-wise copy of this node.
 *
 * NOTE(review): the copied visitor member V keeps the back-pointer it was
 * constructed with (the original node). This is presumably harmless if V is
 * only used during construction; confirm.
 */
TileOperator ParallelOpNode::Clone() const {
  auto copy = tvm::ffi::make_object<ParallelOpNode>(*this);
  return ParallelOp(copy);
}

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
void ParallelOpNode::ExpandLetBindings(
    const Map<Var, PrimExpr> &let_var_to_expr) {
  if (let_var_to_expr.empty())
    return;

  // Helper function to recursively find BufferLoads through let bindings
  std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
    PostOrderVisit(expr, [&](const ObjectRef &node) {
      if (auto bl = node.as<BufferLoadNode>()) {
        if (bl->buffer.scope() == "local.fragment" &&
            !indice_map_.count(bl->buffer)) {
          indice_map_.Set(bl->buffer, bl->indices);
        }
      } else if (auto var_node = node.as<VarNode>()) {
        auto var = tvm::ffi::GetRef<Var>(var_node);
        if (let_var_to_expr.count(var)) {
          expand(let_var_to_expr[var]);
        }
      }
    });
  };

  // Scan all let bindings
  for (const auto &[var, expr] : let_var_to_expr) {
    expand(expr);
  }
}

213
214
215
216
/*!
 * \brief Lower the parallel op by returning the original loop nest unchanged.
 *
 * Neither \p T nor \p analyzer is consulted here.
 */
Stmt ParallelOpNode::Lower(const LowerArgs &T,
                           arith::Analyzer *analyzer) const {
  Stmt lowered = root_;
  return lowered;
}
217

218
bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
219
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
220
221
222
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

223
224
/*! \brief Infer the layout for parallel operations based on different inference
 * levels
225
 *
226
227
228
229
230
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only allows explicitly defined
 * layouts. Returns empty layout map if loop_layout_ is not already defined.
 *                Used when exact layout control is required.
231
 *
232
233
234
 * - kCommon (1): Intermediate level between strict and free.
 *                Allows common layout patterns while maintaining some
 * constraints.
235
 *
236
237
238
239
 * - kFree (0):   Most permissive level. Allows maximum optimization freedom.
 *                Will attempt layout inference even without source buffers.
 *                Can generate new layouts based on vectorization and thread
 * bounds. Used when maximum performance optimization is desired.
240
 */
241
242
LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                                      InferLevel level) const {
243
244
  if (loop_layout_.defined())
    return {};
245

246
247
248
249
250
  // Expand let bindings to find fragment buffer accesses
  if (!T.let_var_to_expr.empty()) {
    const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
  }

251
252
  if (level == InferLevel::kStrict) {
    LayoutMap results;
253
    // Deduce buffers that should be complicated replicated.
254
    // For example:
255
    // for i in T.Parallel(m):
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    //   fragment[0] = x[i]
    // then fragment[0] must be replicated on all threads.
    for (const auto &[buffer, indices] : indice_map_) {
      if (T.layout_map.count(buffer)) {
        continue;
      }
      if (buffer.scope() != "local.fragment")
        continue;

      // Check if all indices are zero
      bool all_indices_zero = true;
      for (const auto &index : indices) {
        if (const auto *imm = index.as<IntImmNode>()) {
          if (imm->value != 0) {
            all_indices_zero = false;
            LOG(FATAL)
                << "Fragment buffer access with non-zero index [" << imm->value
                << "] is not supported. "
                << "Only fragment[0] access is allowed within T.Parallel loop.";
          }
        } else {
          // Non-constant index, not all zero
          all_indices_zero = false;
        }
      }

      // Only set layout if all indices are zero
      if (all_indices_zero) {
        Array<IterVar> forward_vars;
        for (const auto &s : buffer->shape) {
          forward_vars.push_back(
              IterVar(Range(0, s), Var(), IterVarType::kDataPar));
        }
        Var rep;
        auto rep_iter =
            IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);

293
294
295
        // Use default fragment indexing (single output dim) to
        // stay consistent with other ops (e.g., ReduceOp), and
        // bind the thread range for comparability.
296
        const PrimExpr &forward_thread = rep;
297
298
299
300
        auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread,
                             rep_iter)
                        ->BindThreadRange(T.thread_bounds);
        results.Set(buffer, frag);
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
      }
    }
    return results;
  }
  auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
    if (buffer.scope() != "local.fragment")
      return false;
    auto frag = T.layout_map[buffer].as<Fragment>().value();
    // buffer indices should be IntImm
    for (const auto &index : indice_map_[buffer]) {
      if (!index.as<IntImmNode>()) {
        return false;
      } else if (index.as<IntImmNode>()->value != 0) {
        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
      }
    }
    return frag->IsCompletedReplicated();
  };
  // Collect fragment buffers with const index and all fragment_buffers
  std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
  for (const auto &[buffer, indices] : indice_map_) {
    if (buffer.scope() != "local.fragment")
      continue;
    fragment_buffers.push_back(buffer);

    bool is_const_index = true;
    for (const auto &index : indices) {
      if (!index.as<IntImmNode>()) {
        is_const_index = false;
        break;
      }
    }
    if (is_const_index) {
      const_index_fragment_buffer.push_back(buffer);
    }
  }

  // Determine if common layout propagation should be applied.
  // If there are fragment buffers with non-constant indices, we need to
  // propagate the common layout pattern to ensure consistency across all
  // fragments. Example cases:
  //   - Need propagation: frag_a[0] = T.min(frag_a[0], frag_b[i])
  //     (const index frag_a interacts with non-const index frag_b)
  //   - No propagation needed: shared_a[i] = frag_a[0]
  //     (const index frag_a with non-fragment buffer)
346

347
  bool allow_layout_propgate =
348
349
      const_index_fragment_buffer.empty() ||
      (fragment_buffers.size() > const_index_fragment_buffer.size());
350
351
352

  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
353
354
  Buffer replicated_write_buffer; // Backup: fully replicated write buffer

355
  for (const auto &[buffer, indices] : indice_map_) {
356
    if (T.layout_map.count(buffer)) {
357
358
359
360
361
      // skip reducers with rep=ALL
      if (auto info = reducer_info_map_.Get(buffer->data);
          info && info.value()->rep == ReducerRepType::ALL)
        continue;

362
      auto frag = T.layout_map[buffer].as<Fragment>().value();
363
364
      bool is_fully_replicated = buffer_is_completed_replicated(buffer);

365
      if (buffer_is_write_.count(buffer)) {
366
        source_buffer = buffer;
367
368
369
370
      } else {
        // Keep the buffer with largest number of indices
        // (which means the inference based on that buffer is more accurate)
        // as read_source_buffer to get more accurate layout
371
372
373
374
375
        // if the buffer is completed replicated, we don't need to infer the
        // layout from this buffer.
        if ((!read_source_buffer.defined() ||
             indice_map_[buffer].size() >
                 indice_map_[read_source_buffer].size())) {
376
377
          read_source_buffer = buffer;
        }
378
379
380
381
        // If the buffer is not replicated and shape is equal to the
        // source_buffer, use it as source_buffer because the layout inference
        // is more accurate
        if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
382
383
          source_buffer = buffer;
        }
384
      }
385
386
    }
  }
387
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
388
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
389
390
    DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
               << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
391

392
    Fragment result;
393
    if (IsCommonAccessIndice(buffer)) {
394
      result = src_layout;
395
396
    } else {
      Var rep;
397
398
399
400
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
401
      loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
402
403
404
405
406
407
408
409
410
      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
        if (auto opt_var = objref.as<Var>();
            opt_var && inner_vars_.count(*opt_var)) {
          std::ostringstream oss;
          oss << "loop_var_to_thread = " << loop_var_to_thread
              << "contains inner var" << *opt_var;
          throw LayoutConflictException(oss.str());
        }
      });
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425

      try {
        result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                     ->BindThreadRange(T.thread_bounds);
      } catch (const tvm::runtime::Error &err) {
        std::ostringstream msg;
        msg << "Layout inference for buffer `" << buffer->name
            << "` failed inside `T.parallel` loop.";

        msg << "\nUnderlying TVM error: " << err.what();
        msg << "\nProblematic loop AST:\n " << root_;
        msg << "\nHint: ensure the loop extent divides the thread binding or "
               "adjust the fragment mapping.";
        LOG(FATAL) << msg.str();
      }
426
    }
427
428
429
    DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
               << result->DebugOutput() << '\n';
    return result;
430
  };
431
432
433
434
435
436
437
438

  // Try to infer loop layout from buffers in order of preference:
  // 1. Non-replicated write buffer (most reliable)
  // 2. Non-replicated read buffer
  // 3. Fully replicated write buffer (backup, may cause issues)
  // 4. Free inference mode (no source buffer)

  if (source_buffer.defined() && allow_layout_propgate) {
439
440
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
    // For free layout inference
    // If replication exists and buffer has cross-thread shared memory access,
    // add predicate
    bool has_cross_thread_access = false;
    PostOrderVisit(root_, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
        // check if scope is shared or global
        if (store->buffer.scope() == "shared" ||
            store->buffer.scope() == "shared.dyn" ||
            store->buffer.scope() == "global") {
          has_cross_thread_access = true;
        }
      } else if (const auto *load = obj.as<BufferLoadNode>()) {
        // check if scope is shared or global
        if (load->buffer.scope() == "shared" ||
            load->buffer.scope() == "shared.dyn" ||
            load->buffer.scope() == "global") {
          has_cross_thread_access = true;
        }
      }
    });

    // check if loop body contains a "pure" buffer store (i.e., direct
    // assignment, not compound update)
465
466
467
468
    std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers;
    // Buffers that scope is above fragments.
    // global, shared, shared.dyn
    // which can be used to analysis replicate case
469
470
    PostOrderVisit(root_, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
471
472
473
474
475
476
        auto buffer = store->buffer;
        if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
            buffer.scope() == "global") {
          store_shared_global_buffers.emplace_back(buffer);
        } else if (buffer.scope() == "local.fragment") {
          store_fragment_buffers.emplace_back(buffer);
477
478
479
        }
      }
    });
480
    if (read_source_buffer.defined() && allow_layout_propgate) {
481
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
482
483
484
485
    }

    if (!loop_layout_.defined()) {
      // No source buffer available, use free mode inference
486
487
      // Vectorize Size must be aware of the buffer_remap
      // As the pass will do post processing to the layout
488
489
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
490
      int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
491
492
      DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';

493
494
495
496
      PrimExpr loop_total_size = 1;
      for (Stmt l = root_; l.as<For>().has_value();
           l = l.as<For>().value()->body)
        loop_total_size = loop_total_size * l.as<For>().value()->extent;
497
498
      DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
                 << '\n';
499
500
501
502
503
      while (!analyzer_.CanProve(
                 floormod(loop_total_size,
                          T.thread_bounds->extent * vector_size) == 0) &&
             vector_size > 1)
        vector_size /= 2;
504
505
      DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
                 << vector_size << '\n';
506

507
      // Check if coalesced_width is defined
508
509
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
510
        if (const auto *imm = coalesced_width->as<IntImmNode>()) {
511
512
513
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
514
515
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
516
517
518
519
520
521
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
522
523
524
      DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
                 << " ############# vector_size = " << vector_size
                 << ", thread_bounds = " << T.thread_bounds << '\n';
525
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
526
527
      DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                 << loop_layout_->DebugOutput() << '\n';
528
    }
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595

    // Lambda that guards replicated accesses:
    // - When a loop layout replicates a fragment buffer (rep > 1), each thread
    //   observes the same fragment elements. Blindly storing to shared/global
    //   memory in that case would add the same value multiple times.
    // - We therefore restrict the store so that only the replica with rep == 0
    //   performs the update (e.g. global[i] += fragment[i] only fires once).
    // Trigger conditions for this guard:
    // 1) There are cross-thread stores targeting shared/global memory (no
    //    fragment stores in this branch; atomic_add and similar remain TODO).
    // 2) The loop layout replicate extent is greater than 1, inferred from the
    //    thread bounds captured in the layout.

    [this, &store_shared_global_buffers, &store_fragment_buffers,
     &has_cross_thread_access, &const_index_fragment_buffer, &T]() {
      if (is_one(loop_layout_->ReplicateExtent()))
        return;
      if (!has_cross_thread_access)
        return;

      if (!store_fragment_buffers.empty()) {
        // Iterate replicated fragment stores: when the fragment index is a
        // constant (e.g. fragment[0]), every thread touches the same slot, so
        // the rep == 0 predicate is unnecessary. Example: for i in
        // T.Parallel(...):
        //   shared[i] = ...
        //   fragment[0] = ...
        bool replicate_is_from_dynamic_index_fragment = false;
        for (const auto &fragment : store_fragment_buffers) {
          if (!T.layout_map.count(fragment)) {
            continue;
          }

          auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
          if (is_one(fragment_layout->ReplicateExtent()))
            continue;

          if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(),
                                      loop_layout_->ReplicateExtent()))
            continue;
          if (std::find(const_index_fragment_buffer.begin(),
                        const_index_fragment_buffer.end(),
                        fragment) == const_index_fragment_buffer.end()) {
            replicate_is_from_dynamic_index_fragment = true;
          }
        }

        if (!replicate_is_from_dynamic_index_fragment)
          return;

        ICHECK(store_shared_global_buffers.empty())
            << "Invalid layout: cannot have both fragment and shared store "
               "buffers "
               "in replicated loop layout.";
        return;
      } else {
        // Now, store is global or shared
        // or T.call_extern or T.call_intrin ...
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    }();
596
597
598
  } else {
    return {};
  }
599
600
601
602
603
604
605
606
607
608
  // check loop_layout_ is injective
  auto injective_res = loop_layout_->DetectInjective();
  if (!injective_res->errors.empty()) {
    std::ostringstream oss;
    oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
        << '\n'
        << "  errors: " << injective_res->errors << '\n'
        << "  loop AST: " << root_;
    throw LoopLayoutInjectiveException(oss.str());
  }
609

610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
  PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

  auto block_size = T.thread_bounds->extent;
  if (loop_layout_.defined()) {
    if (loop_layout_->ThreadRange().defined()) {
      auto thread_range = loop_layout_->ThreadRange();
      block_size = thread_range->extent;
      AddPredicate(GE(InputPlaceholder(0), thread_range->min));
      AddPredicate(
          LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
    }
  }

  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
    AddPredicate(
        LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
  }

628
  // Step 2: Check that the loop's partition can correctly align with all source
629
630
  // fragment, and infer layout only when it's not yet layout-ed
  LayoutMap results;
631
  for (const auto &[buffer, _] : indice_map_) {
632
633
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
634
635
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
636
637
638
639
      if (!ProveFragmentContains(loop_layout_, fragment, vars,
                                 indice_map_[buffer], analyzer_)) {
        std::ostringstream oss;
        oss << "Layout infer conflict between " << buffer << " and "
640
641
642
            << source_buffer << " in T.Parallel loop:" << '\n'
            << "    loop " << loop_layout_->DebugOutput() << '\n'
            << "    fragment " << fragment->DebugOutput() << '\n';
643
        throw LayoutConflictException(oss.str());
644
      }
645
646
647
648
    } else {
      auto dst_layout =
          CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
      results.Set(buffer, dst_layout);
649
    }
650
651
652
653
  }
  return results;
}

654
/*!
 * \brief Return the accumulated predicate, with the thread placeholder
 *        (InputPlaceholder(0)) replaced by \p thread_var.
 *
 * Returns std::nullopt when no predicate has been recorded.
 */
Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
  if (!predicate_.defined()) {
    return std::nullopt;
  }
  const PrimExpr &recorded = predicate_.value();
  return Substitute(recorded, {{InputPlaceholder(0), thread_var}});
}

662
/*!
 * \brief Derive a fragment layout for \p buffer from the loop layout and the
 *        index expressions recorded for the buffer in indice_map_.
 *
 * - If the buffer is indexed exactly by the parallel loop variables, the loop
 *   layout itself is returned.
 * - If the recorded indices form a bijective iterator map over the loop
 *   variables, they are inverted directly. Buffer coordinates map back to loop
 *   iterations, and then to threads through the loop layout. The replicate
 *   extent equals the loop layout's replicate extent.
 * - Otherwise:
 *   - the parts of the loop space the indices do not use are flattened into
 *     one extra index (DivideUnusedIterators + MakeFlattenedExpression);
 *   - the extended index list is inverted;
 *   - the size of the extra index multiplies the replicate extent.
 * Resulting fragments are passed through CondenseReplicateVar().
 *
 * Requires loop_layout_ to be defined (ICHECK).
 */
Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer)) {
    return loop_layout_;
  }

  // One placeholder per logical dimension of the buffer.
  auto buffer_coordinates = [&buffer]() {
    Array<PrimExpr> coords;
    for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
      coords.push_back(InputPlaceholder(dim));
    }
    return coords;
  };

  const Array<PrimExpr> access = indice_map_[buffer];

  // Fast path: a bijective access pattern is inverted as-is, without
  // introducing a synthetic replicate dimension.
  auto bijective_check = arith::DetectIterMap(
      access, ToVMap(loop_vars_), 1, arith::IterMapLevel::Bijective,
      const_cast<arith::Analyzer *>(&analyzer_));
  if (bijective_check->errors.empty()) {
    Layout access_inverse = Layout(loop_vars_, access)->Inverse();
    const PrimExpr unit_extent = 1;
    PrimExpr rep_extent = unit_extent * loop_layout_->ReplicateExtent();
    // No replicate variable is passed to ForwardThread here.
    // NOTE(review): presumably the loop layout's replication placeholder then
    // stays in the expression and serves as the new fragment's replicate var.
    PrimExpr thread = loop_layout_->ForwardThread(
        access_inverse->Forward(buffer_coordinates()), std::nullopt);
    return Fragment(buffer->shape, {}, thread, rep_extent, std::nullopt)
        ->CondenseReplicateVar();
  }

  // General path: capture the truly-unused pieces of the loop space in an
  // extra flattened index, then invert the extended index list.
  PrimExpr unused_index = MakeFlattenedExpression(
      DivideUnusedIterators(access, loop_vars_, &analyzer_));
  Array<PrimExpr> extended_access = access;
  extended_access.push_back(unused_index);
  Layout extended_inverse = Layout(loop_vars_, extended_access)->Inverse();

  // Size of the extra index: the last input dimension of the inverse.
  PrimExpr unused_extent = extended_inverse->InputShape().back();
  PrimExpr rep_extent = unused_extent * loop_layout_->ReplicateExtent();

  // The fragment's replicate index is split into (extra index, loop replica).
  Array<PrimExpr> inverse_inputs = buffer_coordinates();
  inverse_inputs.push_back(FloorMod(ReplicationPlaceholder(), unused_extent));
  PrimExpr thread = loop_layout_->ForwardThread(
      extended_inverse->Forward(inverse_inputs),
      FloorDiv(ReplicationPlaceholder(), unused_extent));
  return Fragment(buffer->shape, {}, thread, rep_extent, std::nullopt)
      ->CondenseReplicateVar();
}

715
// Register ParallelOpNode's reflection information. Judging by the macro
// name, this presumably runs once at static-initialization time; confirm.
TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }
716

717
718
} // namespace tl
} // namespace tvm