/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file op/parallel.cc
 * \brief Define the Parallel for-loop operator
 */

#include "parallel.h"

#include <tvm/tir/op.h>

#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"

namespace tvm {
namespace tl {

using namespace tir;

namespace attr {
/*! \brief Annotation key specifying the coalesced width used when vectorizing the loop. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

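/*!
 * \brief Rewrite a loop nest so that loads and stores to remapped buffers
 *        access the remapped buffer with indices transformed by its layout.
 */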
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];

      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  Map<Buffer, Buffer> buffer_remap_;
  Map<Buffer, Layout> layout_map_;
};

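// Collect the iteration variable of each parallel loop and bind its range
// into the arithmetic analyzer.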
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

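// Record the indices used to write each local.fragment buffer; every access
// to the same buffer inside the loop nest must use identical indices.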
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
    p->buffer_is_write_.insert(op->buffer);
  }
  StmtExprVisitor::VisitStmt_(op);
}

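// Record the indices used to read each local.fragment buffer, with the same
// consistency requirement as for writes.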
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
  }
  StmtExprVisitor::VisitExpr_(op);
}

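// Construct a ParallelOp from the root parallel loop nest and collect its
// loop variables and buffer access indices.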
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

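// Infer the thread layout (loop partition) of this parallel loop nest:
// Step 1 derives it from a known source fragment or plans a free partition,
// Step 2 checks that it aligns with all fragments whose layouts are known,
// and Step 3 completes the layouts of the remaining fragments.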
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer the loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer))
        source_buffer = buffer;
      else
        read_source_buffer = buffer;
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
    }
  };
  if (source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
  } else if (level == InferLevel::kFree) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // The loop itself does not need to be replicated, so de-replicate if possible.
      if (!is_one(loop_layout_->ReplicateExtent()))
        loop_layout_ = loop_layout_->DeReplicate();
      // If replication still remains, add a predicate selecting only replica 0.
      if (!is_one(loop_layout_->ReplicateExtent())) {
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0));
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
      }
    } else {
      // The vectorize size must be computed against buffer_remap,
      // since later passes will post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by expected
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
          LOG(FATAL) << "coalesced_width should be an IntImmNode.";
        }
      }

      loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
    }
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent,
                                 static_cast<int>(T.block_size)))
      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
  } else {
    return {};
  }
  // Step 2: Check that the loop's partition aligns correctly with all source
  // fragments
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases;
      // the rhs needs to be wildcard-matched against the lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: Infer the remaining fragments' layouts from the loop's partition
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer)) {
      results.Set(buffer, CompleteBufferFragment(buffer));
    }
    // Some conflicts may still exist here, but that is acceptable.

    // Layout inference conflicts for local.fragment cannot be handled here
    // because the source_buffer is not always available.
    if (buffer.scope() == "local.fragment" && source_buffer.defined() &&
        source_buffer.scope() == "local.fragment") {
      if (T.layout_map.count(buffer)) {
        const FragmentNode *src_layout =
            T.layout_map[buffer].as<Fragment>().get();
        Fragment dst_layout_fragment = CompleteBufferFragment(buffer);
        const FragmentNode *dst_layout =
            dst_layout_fragment.as<Fragment>().get();
        if (src_layout && dst_layout) {
          ICHECK(src_layout->IsEqual(dst_layout, true))
              << "Layout may conflict with ParallelOp for buffer " << buffer
              << "\nError body begin:\n"
              << GetRoot()->body << "\nError body end"
              << "\nLHS = " << src_layout->DebugOutput()
              << "\nRHS = " << dst_layout->DebugOutput()
              << "\nYou may need to use shared memory to transform the "
                 "layout";
        }
      }
    }
  }
  return results;
}

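// Return the predicate accumulated during layout inference, with the thread
// placeholder substituted by the given thread variable, or NullOpt if no
// predicate is required.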
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  if (predicate_.defined()) {
    return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
  } else {
    return NullOpt;
  }
}

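// Complete the fragment layout of a buffer from the inferred loop layout by
// inverting the mapping from loop variables to buffer indices and composing
// it with the loop's thread mapping.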
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;

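  // Flatten the loop iterators that do not appear in the buffer's access
  // indices into a single replication expression.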
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));

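  // Build a bijective mapping from the loop variables to the buffer indices
  // plus the replication index, and invert it.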
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();

  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;

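  // Map the buffer indices (and the replication split) back to loop indices,
  // then through the loop layout's thread mapping to obtain the thread binding.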
  Array<PrimExpr> fwd;
  for (size_t i = 0; i < buffer->shape.size(); i++) {
    fwd.push_back(InputPlaceholder(i));
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));

  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm