/*!
 * \file op/parallel.cc
 * \brief Define Parallel for operator
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Marks the coalesced width used when the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

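/*! \brief Rewrites buffer accesses in a loop so that vectorization analysis
 * sees the final buffers: every load/store of a buffer present in
 * buffer_remap is redirected to the remapped buffer, with its indices
 * transformed through the corresponding layout's Forward mapping.
 */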
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

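// Collects the parallel loop nest: records each loop variable as a
// data-parallel IterVar and binds its range in the parent ParallelOp's
// analyzer.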
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

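// Records the access indices of every local.fragment buffer that is written,
// requiring all accesses to a given buffer to use identical indices.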
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

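// Same as above, but for local.fragment buffers that are read.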
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

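// An access is "common" when the buffer is indexed directly by the loop
// variables, in loop order (e.g. buf[i, j] inside a parallel loop over i, j).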
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

/*! \brief Infer the layout for parallel operations based on different
 * inference levels.
 *
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only allows explicitly defined
 *   layouts. Returns an empty layout map if loop_layout_ is not already
 *   defined. Used when exact layout control is required.
 * - kCommon (1): Intermediate level between strict and free. Allows common
 *   layout patterns while maintaining some constraints.
 * - kFree (0):   Most permissive level. Allows maximum optimization freedom.
 *   Will attempt layout inference even without source buffers, and can
 *   generate new layouts based on vectorization and thread bounds. Used when
 *   maximum performance optimization is desired.
 */
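// A minimal sketch of how the levels behave for a hypothetical TileLang-style
// parallel loop (illustrative only, not taken from this file):
//
//   for i, j in T.Parallel(64, 64):
//       B_frag[i, j] = A_frag[i, j] + 1
//
// - kStrict: with no explicitly provided loop layout, InferLayout returns an
//   empty map and leaves loop_layout_ undefined.
// - kCommon: if A_frag or B_frag already carries a Fragment layout in
//   T.layout_map, the loop partition is derived from that source fragment
//   (Step 1 below).
// - kFree: even without a source fragment, a partition is planned from the
//   inferred vectorize width and T.thread_bounds via PlanLoopPartition.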
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer the loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, indices] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer)) {
        source_buffer = buffer;
      } else {
        // Keep the read buffer with the largest number of indices as
        // read_source_buffer, since inference based on more indices yields
        // a more accurate layout.
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
        // If the buffer is not replicated and no write source has been
        // found yet, use it as source_buffer, because layout inference from
        // a non-replicated fragment is more accurate.
        if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
          source_buffer = buffer;
        }
      }
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
          ->BindThreadRange(T.thread_bounds);
    }
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // // The loop doesn't need to be replicated.
      // if (!is_one(loop_layout_->ReplicateExtent()))
      //   loop_layout_ = loop_layout_->DeReplicate();

      // For free layout inference: if the loop layout is replicated and the
      // body has cross-thread access to shared or global memory, add a
      // predicate so that only a single replica performs the access.
      bool has_cross_thread_access = false;
      PostOrderVisit(root_, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          // check if scope is shared or global
          if (store->buffer.scope() == "shared" ||
              store->buffer.scope() == "shared.dyn" ||
              store->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          // check if scope is shared or global
          if (load->buffer.scope() == "shared" ||
              load->buffer.scope() == "shared.dyn" ||
              load->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        }
      });

      // check if loop body contains a "pure" buffer store (i.e., direct
      // assignment, not compound update)
      bool has_pure_buffer_store = false;
      PostOrderVisit(root_, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          // Check if the value is a direct load from another buffer
          // (i.e., b[i] = a[i]).
          if (const auto *load = store->value.as<BufferLoadNode>()) {
            has_pure_buffer_store = true;
          }
        }
      });

      if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
          !has_pure_buffer_store) {
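        // Only one replica should perform the cross-thread access. Recover
        // the replication index of the current thread by inverting the loop
        // layout: fix all spatial outputs to 0, feed the bounds-relative
        // thread index, and take the last output of the inverse map, which
        // is the replication coordinate. Predicating it to 0 keeps a single
        // replica active (e.g. with ReplicateExtent == 2, roughly half of
        // the threads skip the access; the numbers are illustrative).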
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    } else {
      // The vectorize size must be computed on the remapped buffers, since
      // later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
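      // The annotation caps the planned vector length. For example
      // (hypothetical numbers): if GetVectorizeSize infers 8 and the loop is
      // annotated with coalesced_width = 4, the partition is planned with a
      // vector length of 4, while coalesced_width = 3 would abort because
      // 8 % 3 != 0.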
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
    }
  } else {
    return {};
  }

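  // Bound the thread dimension: if the planned layout covers only a
  // sub-range of the block's threads, predicate the loop on that range; if
  // its thread extent is smaller than the block size, mask off the excess
  // threads.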
  PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

  auto block_size = T.thread_bounds->extent;
  if (loop_layout_.defined()) {
    if (loop_layout_->ThreadRange().defined()) {
      auto thread_range = loop_layout_->ThreadRange();
      block_size = thread_range->extent;
      AddPredicate(GE(InputPlaceholder(0), thread_range->min));
      AddPredicate(
          LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
    }
  }

  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
    AddPredicate(
        LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
  }

  // Step 2: Check that the loop's partition aligns correctly with all source
  // fragments.
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: Infer the other fragments' layouts from the loop's partition.
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer)) {
      results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
                              T.thread_bounds));
    }

    // A layout inference conflict for local.fragment cannot be handled here,
    // because the source_buffer is not always available.
    // (zhengju) Do not modify a strict layout even if it conflicts with the
    // dst layout. This does not influence the result, because a strict
    // layout usually has rep = 1. Since the real layout map is controlled by
    // layout_inference.cc, this check should be added there.
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment =
            CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (as_const_int(dst_layout->ReplicateExtent()) &&
            as_const_int(src_layout->ReplicateExtent()) &&
            (*as_const_int(dst_layout->ReplicateExtent()) >
             *as_const_int(src_layout->ReplicateExtent()))) {
          results.Set(buffer, dst_layout_fragment);
          continue;
        }
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
              << " vs. " << source_buffer << "\nError body begin:\n"
              << GetRoot()->body << "\nError body end"
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use shared memory to transform the "
                 "layout";
        }
      }
    }
  }
  return results;
}

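/*! \brief Returns the accumulated thread predicate with the thread
 * placeholder substituted by `thread_var`, or NullOpt if no predicate was
 * added during layout inference.
 */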
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return NullOpt;
  }
}

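/*! \brief Derives a Fragment layout for `buffer` from the inferred loop
 * layout.
 *
 * The buffer's access indices are inverted, with an extra coordinate (rep_b)
 * covering loop iterators that do not appear in the indices, and the result
 * is composed with the loop layout's thread mapping. The resulting fragment
 * replicates over both rep_b and the loop layout's own replication extent.
 */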
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer)) {
    return loop_layout_;
  }
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm