parallel.cc 12.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/*!
 * \file op/parallel.cc
 * \brief Define Parallel for operator
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Loop annotation key requesting a specific vectorization (coalescing)
 *  width for a parallel loop. ParallelOp::InferLayout reads it as an IntImm and
 *  uses it in place of the automatically computed vector size, which it must
 *  evenly divide. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

class IfBufferRemapLoopGenerator : public StmtExprMutator {
26
public:
27
28
29
30
31
32
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

33
34
35
private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
36
37
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

38
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
39
40
41
42
43
44
45
46
47
48
49
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

50
  Stmt VisitStmt_(const BufferStoreNode *op) final {
51
52
53
54
55
56
57
58
59
60
61
62
63
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

64
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
65
  ICHECK(op->kind == ForKind::kParallel);
66
67
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
68
69
70
71
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

72
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
73
74
75
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
76
77
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
78
79
80
81
82
83
84
85
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

86
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
87
88
89
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
90
91
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
92
93
94
95
96
97
98
99
100
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

/*!
 * \brief Wrap a parallel loop nest and immediately collect its loop variables
 *        and fragment-buffer accesses with the nest visitor.
 */
ParallelOp::ParallelOp(For root) : root_(std::move(root)), V(this) {
  V.VisitStmt(root_);
}

101
102
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
103
104
105
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*! \brief Infer the layout for parallel operations based on different inference
 * levels
 *
 * The inference level controls how aggressively we try to infer and optimize
 * layouts:
 * - kStrict (2): Most conservative level. Only allows explicitly defined
 * layouts. Returns empty layout map if loop_layout_ is not already defined.
 *                Used when exact layout control is required.
 *
 * - kCommon (1): Intermediate level between strict and free.
 *                Allows common layout patterns while maintaining some
 * constraints.
 *
 * - kFree (0):   Most permissive level. Allows maximum optimization freedom.
 *                Will attempt layout inference even without source buffers.
 *                Can generate new layouts based on vectorization and thread
 * bounds. Used when maximum performance optimization is desired.
 */
124
125
126
127
128
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};
129
130
131

  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
132
  for (const auto &[buffer, indices] : indice_map_) {
133
134
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
135
      if (buffer_is_write_.count(buffer)) {
136
        source_buffer = buffer;
137
138
139
140
141
142
143
144
145
      } else {
        // Keep the buffer with largest number of indices
        // (which means the inference based on that buffer is more accurate)
        // as read_source_buffer to get more accurate layout
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
146
147
148
149
        // If the buffer is not replicated and shape is equal to the
        // source_buffer, use it as source_buffer because the layout inference
        // is more accurate
        if (is_one(frag->ReplicateExtent()) && !source_buffer.defined()) {
150
151
          source_buffer = buffer;
        }
152
      }
153
154
    }
  }
155
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
156
157
158
159
160
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
161
162
163
164
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
165
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
166
          ->BindThreadRange(T.thread_bounds);
167
168
169
170
171
172
173
    }
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
174
175
176
177
178
179
180
181
182
183
184
185
186
      // // Loop don't need to be replicated.
      // if (!is_one(loop_layout_->ReplicateExtent()))
      //   loop_layout_ = loop_layout_->DeReplicate();
      // // if still has replication, add a condition
      // if (!is_one(loop_layout_->ReplicateExtent())) {
      //   auto inv = loop_layout_->Inverse();
      //   Array<PrimExpr> fwd;
      //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
      //     fwd.push_back(0);
      //   fwd.push_back(InputPlaceholder(0));
      //   auto rep = inv->Forward(fwd).back();
      //   AddPredicate(EQ(rep, 0));
      // }
187
188
189
    } else {
      // Vectorize Size must be aware of the buffer_remap
      // As the pass will do post processing to the layout
190
191
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
192
193
194
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
195
196
197
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
198
199
200
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
201
202
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
203
204
205
206
207
208
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
209
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
210
211
212
213
    }
  } else {
    return {};
  }
214

215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
  PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

  auto block_size = T.thread_bounds->extent;
  if (loop_layout_.defined()) {
    if (loop_layout_->ThreadRange().defined()) {
      auto thread_range = loop_layout_->ThreadRange();
      block_size = thread_range->extent;
      AddPredicate(GE(InputPlaceholder(0), thread_range->min));
      AddPredicate(
          LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
    }
  }

  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
    AddPredicate(
        LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
  }

233
234
235
  // Step 2: Check that the loop's partition can correctly align with all source
  // fragment
  for (const auto &[buffer, _] : indice_map_) {
236
237
238
239
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
240
241
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
242
        continue;
243
244
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
245
246
247
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
248
249
250
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
251
252
253
254
    }
  }
  // Step 3: Infer other fragment's layout from the loop's partition
  LayoutMap results;
255
  for (const auto &[buffer, _] : indice_map_) {
256
    if (!T.layout_map.count(buffer)) {
257
      results.Set(buffer, CompleteBufferFragment(buffer)->BindThreadRange(
258
                              T.thread_bounds));
259
260
261
262
263
264
265
266
267
    }

    // Layout infer conflict for local.fragment can noy be handled here
    // because the source_buffer is not always available
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
268
        Fragment dst_layout_fragment =
269
            CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
270
271
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
272
273
274
275
276
277
278
        if (as_const_int(dst_layout->ReplicateExtent()) &&
            as_const_int(src_layout->ReplicateExtent()) &&
            (*as_const_int(dst_layout->ReplicateExtent()) >
             *as_const_int(src_layout->ReplicateExtent()))) {
          results.Set(buffer, dst_layout_fragment);
          continue;
        }
279
280
281
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
282
              << " vs. " << source_buffer << "\nError body begin:\n"
283
              << GetRoot()->body << "\nError body end"
284
285
286
287
288
289
290
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use a shared memory to transform the "
                 "layout";
        }
      }
    }
291
292
293
294
295
296
297
298
299
300
301
302
  }
  return results;
}

/*!
 * \brief Instantiate the accumulated loop predicate for a concrete variable.
 * \param thread_var Variable substituted for InputPlaceholder(0), the
 *        placeholder that InferLayout compares against thread ranges.
 * \return The substituted predicate, or NullOpt if no predicate was added.
 */
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (!predicate_.defined())
    return NullOpt;
  return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
}

303
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
304
  ICHECK(loop_layout_.defined());
305
  if (IsCommonAccessIndice(buffer)) {
306
    return loop_layout_;
307
  }
308
309
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
310
311
312
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
313
314
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
315
316
317
318
319
320
321
322
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
323
324
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
325
326
327
328
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

329
330
} // namespace tl
} // namespace tvm