/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file op/parallel.cc
 * \brief Define the Parallel for-loop operator.
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Annotation key specifying the coalesced width used when vectorizing the loop. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

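/*!
 * \brief Rewrite buffer accesses inside a loop nest according to a buffer
 *        remap.
 *
 * Every BufferLoad/BufferStore whose buffer appears in buffer_remap is
 * redirected to the remapped buffer, with its indices transformed through the
 * corresponding layout's Forward mapping.
 */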
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

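// Record each parallel loop variable as a data-parallel IterVar and bind its
// range to the analyzer so later simplification knows the loop bounds.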
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

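// Track writes to "local.fragment" buffers: remember the access indices
// (which must be structurally identical across accesses) and mark the buffer
// as written.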
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

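// Track reads of "local.fragment" buffers, recording access indices in the
// same way as for stores.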
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

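// Construct the ParallelOp by walking the parallel loop nest once, collecting
// loop variables and the access indices of fragment buffers.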
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

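// A buffer uses the "common" access pattern when it is indexed exactly by the
// loop variables, in loop order.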
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

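// Infer the loop partition (loop_layout_) from fragment buffers whose layouts
// are already known, then propose layouts for the remaining fragments.
// Strict-level inference is skipped; free-level inference may fall back to
// planning a partition from the vectorization size.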
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  auto block_size = T.thread_bounds->extent - T.thread_bounds->min;
  // Step 1: try to infer loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
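  // A written fragment takes priority as the inference source; otherwise the
  // most precisely indexed read fragment is used (at kFree level only).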
  for (const auto &[buffer, indices] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer)) {
        source_buffer = buffer;
      } else {
        // Keep the read buffer with the largest number of indices as
        // read_source_buffer, since inference based on it gives a more
        // accurate layout.
        if (!read_source_buffer.defined() ||
            indice_map_[buffer].size() >
                indice_map_[read_source_buffer].size()) {
          read_source_buffer = buffer;
        }
      }
    }
  }
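  // Derive the loop partition from a source fragment: reuse its layout when
  // the buffer is indexed directly by the loop variables, otherwise compose
  // the fragment's thread mapping with this loop's access indices.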
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
    }
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // The loop itself does not need to be replicated, so try to remove the
      // replication first.
      if (!is_one(loop_layout_->ReplicateExtent()))
        loop_layout_ = loop_layout_->DeReplicate();
      // If replication remains, recover this thread's replication index via
      // the inverse layout and add a predicate so that only replica 0 executes.
      if (!is_one(loop_layout_->ReplicateExtent())) {
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0));
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    } else {
      // The vectorize size must be computed on the remapped buffers, since
      // later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }
      loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
    }
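    // If the loop's thread extent may differ from the block size, restrict
    // execution to the threads covered by the loop.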
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent, block_size))
      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
  } else {
    return {};
  }
  // Step 2: Check that the loop's partition correctly aligns with all source
  // fragments
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: Infer the remaining fragments' layouts from the loop's partition
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer)) {
      results.Set(buffer, CompleteBufferFragment(buffer));
    }
    // Some conflicts may still exist here, but that is acceptable.

    // Layout-inference conflicts for local.fragment cannot always be handled
    // here because source_buffer is not always available.
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment = CompleteBufferFragment(buffer);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
              << "\nError body begin:\n"
              << GetRoot()->body << "\nError body end"
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use shared memory to transform the "
                 "layout";
        }
      }
    }
  }
  return results;
}

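// Return the accumulated predicate with the thread placeholder substituted by
// the given thread variable, or NullOpt if no predicate was added.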
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return NullOpt;
  }
}

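// Complete a fragment layout for `buffer` from the inferred loop partition:
// loop iterators that do not appear in the buffer's indices become
// replication, and the indexing map is inverted so each buffer element can be
// mapped back to the thread (and replica) that produces it.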
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;

  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));

  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();

  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;

  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));

  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm