/*!
 * \file op/parallel.cc
 * \brief Define the Parallel for-loop operator.
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Annotation key for the coalesced width used when vectorizing the loop. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

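/*!
 * \brief Rewrites BufferLoad/BufferStore nodes through buffer_remap_ and
 * layout_map_, so that later analyses (e.g. the vectorize-size inference in
 * ParallelOp::InferLayout) see the loop as it will look after layout
 * remapping.
 */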
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

/*! \brief Infer the layout for parallel operations based on different
 * inference levels.
 *
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only allows explicitly defined
 *                layouts. Returns an empty layout map if loop_layout_ is not
 *                already defined. Used when exact layout control is required.
 *
 * - kCommon (1): Intermediate level between strict and free. Allows common
 *                layout patterns while maintaining some constraints.
 *
 * - kFree (0):   Most permissive level. Allows maximum optimization freedom.
 *                Will attempt layout inference even without source buffers.
 *                Can generate new layouts based on vectorization and thread
 *                bounds. Used when maximum performance optimization is
 *                desired.
 */
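//
// A minimal usage sketch (illustrative only; `args` stands for a
// LayoutInferArgs prepared by the caller and `loop` for the parallel For
// nest, neither of which is defined in this file):
//
//   ParallelOp op(loop);
//   LayoutMap strict_map = op.InferLayout(args, InferLevel::kStrict);
//   LayoutMap common_map = op.InferLayout(args, InferLevel::kCommon);
//   LayoutMap free_map   = op.InferLayout(args, InferLevel::kFree);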
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, indices] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer)) {
        source_buffer = buffer;
      } else {
        // Prefer the buffer accessed with the most indices as
        // read_source_buffer, since layout inference based on it is more
        // accurate.
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
        // If the fragment is not replicated and no source_buffer has been
        // chosen yet, use this buffer as the source_buffer, since layout
        // inference from an unreplicated fragment is more accurate.
        if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
          source_buffer = buffer;
        }
      }
    }
  }
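  // Derive the loop partition from a buffer whose fragment layout is already
  // known: if the loop indices access the buffer directly, reuse its layout;
  // otherwise compose the fragment's thread mapping with the access indices
  // and bind the result to the thread bounds.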
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
          ->BindThreadRange(T.thread_bounds);
    }
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // // Loop don't need to be replicated.
      // if (!is_one(loop_layout_->ReplicateExtent()))
      //   loop_layout_ = loop_layout_->DeReplicate();

      // For free layout inference:
      // if replication exists and the loop performs cross-thread (shared or
      // global) memory accesses, add a predicate.
      bool has_cross_thread_access = false;
      PostOrderVisit(root_, [&](const ObjectRef &obj) {
        if (const auto *store = obj.as<BufferStoreNode>()) {
          // check if scope is shared or global
          if (store->buffer.scope() == "shared" ||
              store->buffer.scope() == "shared.dyn" ||
              store->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        } else if (const auto *load = obj.as<BufferLoadNode>()) {
          // check if scope is shared or global
          if (load->buffer.scope() == "shared" ||
              load->buffer.scope() == "shared.dyn" ||
              load->buffer.scope() == "global") {
            has_cross_thread_access = true;
          }
        }
      });

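      // Recover this thread's replication index from the inverse loop layout
      // and require it to be zero, so that only one replica performs the
      // cross-thread (shared/global) accesses. For example (illustrative):
      // with ReplicateExtent() == 2, threads holding replica index 1 are
      // predicated off and each shared/global element is written exactly once.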
      if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access) {
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0));
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    } else {
      // The vectorize size must be computed on the remapped buffers, since
      // later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
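      // For example (illustrative): if GetVectorizeSize infers 8 and the loop
      // is annotated with coalesced_width = 4, the partition below uses a
      // vector size of 4; an annotation of 3 would abort above since
      // 8 % 3 != 0.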
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
    }
  } else {
    return {};
  }

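  // Restrict execution to the threads the loop layout actually uses: if the
  // layout is bound to a thread sub-range, predicate on that range; if it
  // covers fewer threads than the block provides, mask out the remainder.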
  PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

  auto block_size = T.thread_bounds->extent;
  if (loop_layout_.defined()) {
    if (loop_layout_->ThreadRange().defined()) {
      auto thread_range = loop_layout_->ThreadRange();
      block_size = thread_range->extent;
      AddPredicate(GE(InputPlaceholder(0), thread_range->min));
      AddPredicate(
          LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
    }
  }

  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
    AddPredicate(
        LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
  }

  // Step 2: Check that the loop's partition aligns correctly with all source
  // fragments.
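  // For example (illustrative): if the loop layout maps loop index i to
  // thread i % 32 and a source fragment maps its buffer index i to thread
  // i % 32, lhs - rhs simplifies to 0 and the layouts agree; a nonzero
  // difference would mean the loop touches fragment elements owned by another
  // thread.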
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: Infer the other fragments' layouts from the loop's partition.
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer)) {
      results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
                              T.thread_bounds));
    }

    // A layout inference conflict for local.fragment cannot always be handled
    // here, because the source_buffer is not always available.
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment =
            CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (as_const_int(dst_layout->ReplicateExtent()) &&
            as_const_int(src_layout->ReplicateExtent()) &&
            (*as_const_int(dst_layout->ReplicateExtent()) >
             *as_const_int(src_layout->ReplicateExtent()))) {
          results.Set(buffer, dst_layout_fragment);
          continue;
        }
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
              << " vs. " << source_buffer << "\nError body begin:\n"
              << GetRoot()->body << "\nError body end"
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use shared memory to transform the "
                 "layout";
        }
      }
    }
  }
  return results;
}

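/*!
 * \brief Return the accumulated loop predicate with the thread placeholder
 * substituted by `thread_var`, or NullOpt if no predicate was added.
 *
 * Usage sketch (illustrative; `tx` stands for the concrete thread variable
 * available during lowering and is not defined in this file):
 *   auto pred = op.GetPredicate(tx);
 *   if (pred.defined()) body = IfThenElse(pred.value(), body);
 */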
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return NullOpt;
  }
}

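/*!
 * \brief Complete a fragment layout for `buffer` from the inferred loop
 * layout: the loop iterators unused by the buffer's indices become a
 * replication term, the augmented index mapping is inverted, and the result
 * is composed with the loop's thread mapping so that every buffer element is
 * assigned a thread and a replication index.
 */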
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer)) {
    return loop_layout_;
  }
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm