lower_opaque_block.cc 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lower_opaque_block.cc
 */

#include <tvm/ffi/reflection/registry.h>
25
#include <tvm/ir/attrs.h>
26
27
28
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

29
#include <string>
30
31
#include <utility>

32
#include "../op/builtin.h"
33
34
35
36
37
38
39
40
41
42
43
44
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace tir::attr;
/*!
 * \brief Remove Block to ensure that the TIR can not be scheduled again.
 */
class OpaqueBlockLower : public StmtExprMutator {
public:
45
46
  static PrimFunc Rewrite(PrimFunc f) {
    auto fptr = f.CopyOnWrite();
47
    OpaqueBlockLower lower;
48
49
50
51
52
53
54
55
56
57
58
    if (auto existing =
            fptr->attrs.GetAttr<Map<Var, PrimExpr>>(tl::attr::kLocalVarInit)) {
      lower.local_var_init_map_ = existing.value();
    }
    lower.storage_align_ = CollectStorageAlignAnnotation(fptr->body);
    fptr->body = lower(std::move(fptr->body));
    if (!lower.local_var_init_map_.empty()) {
      f = WithAttr(std::move(f), tl::attr::kLocalVarInit,
                   lower.local_var_init_map_);
    }
    return f;
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  }

private:
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    // We have convert blocks into opaque blocks in previous passes.
    ICHECK(op->iter_values.empty())
        << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
           "call pass ConvertBlocksToOpaque before.";
    // Step 1. Visit the body
    Block new_block = Downcast<Block>(this->VisitStmt(op->block));
    PrimExpr predicate = this->VisitExpr(op->predicate);
    // Step 2. Transform the `predicate` to if-then-else
    Stmt body = new_block->body;
    if (!is_one(predicate)) {
      body = IfThenElse(predicate, std::move(body));
    }
75
76
77
78
79
80
81
    // Step 3. Handle annotations, block annotations are not preserved by
    // default.
    std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
    HandleAnnotations(new_block->annotations, &pragma_attrs, /*is_block=*/true,
                      new_block->alloc_buffers);

    // Step 4. Handle allocations in reverse order
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
      const Buffer &buffer = new_block->alloc_buffers[i - 1];
      Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
      body = DeclBuffer(buffer, std::move(body));
      Map<String, ffi::Any> allocate_annotations;
      auto it = storage_align_.find(buffer->data);
      if (it != storage_align_.end()) {
        StorageAlignAnnotation allocate_aligns;
        for (auto tuple : it->second) {
          tuple.Set<0>(-1);
          allocate_aligns.push_back(tuple);
        }
        allocate_annotations.Set(tir::attr::buffer_dim_align, allocate_aligns);
      }
96
97
98
99
100
      auto init_it = local_var_init_map_.find(buffer->data);
      if (init_it != local_var_init_map_.end()) {
        const PrimExpr &init = (*init_it).second;
        allocate_annotations.Set(tl::attr::kLocalVarInit, init);
      }
101
102
103
      body = Allocate(buffer->data, buffer->dtype, allocation_shape,
                      const_true(), std::move(body), allocate_annotations);
    }
104
    // Step 5. Insert attribute statements converted from pragmas
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
      body = AttrStmt(Integer(0), it->first, it->second, std::move(body));
    }
    return body;
  }
  Stmt VisitStmt_(const BlockNode *op) final {
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    if (block->annotations.count("stmt_group")) {
      return block->body;
    }
    return block;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    // Step 1. Update unit loop info.
    PrimExpr min = this->VisitExpr(op->min);
    PrimExpr extent = this->VisitExpr(op->extent);
    if (is_one(extent) && op->annotations.empty()) {
      // handling unit loop
      unit_loop_vars_[op->loop_var] = min;
    }
    // Step 2. Visit recursively
    Stmt body = this->VisitStmt(op->body);
    // Step 3. Handle annotations
    std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
    Map<String, ffi::Any> new_annotations =
        HandleAnnotations(op->annotations, &pragma_attrs, /*is_block=*/false);
    // Step 4. Create new For loop accordingly
    if (op->kind == ForKind::kThreadBinding) {
      // Case 1. Thread binding
      ICHECK(op->thread_binding.defined());
      String thread_tag = op->thread_binding.value()->thread_tag;
      body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
    } else if (is_one(extent) && op->annotations.empty()) {
      // Case 2. Unit loop
      return body;
    } else {
      // Case 3. An ordinary loop
      body = For(op->loop_var, std::move(min), std::move(extent), op->kind,
                 std::move(body), std::nullopt, new_annotations);
    }
    // Step 5. Insert nested attrs
    for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
      body = AttrStmt(op->loop_var, it->first, it->second, std::move(body));
    }
    return body;
  }

  PrimExpr VisitExpr_(const VarNode *op) final {
154
    Var var = tvm::ffi::GetRef<Var>(op);
155
156
157
158
159
160
161
162
163
164
165
166
167
168
    auto it = unit_loop_vars_.find(var);
    if (it == unit_loop_vars_.end()) {
      return var;

    } else {
      PrimExpr expr = it->second;
      if (expr.dtype() != var.dtype()) {
        expr = tvm::cast(var.dtype(), std::move(expr));
      }
      return expr;
    }
  }

  static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var,
169
170
                               const String &thread_tag, Stmt body) {
    IterVar iter_var(/*dom=*/Range::FromMinExtent(std::move(min), extent),
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
                     /*var=*/std::move(var),
                     /*iter_type=*/IterVarType::kThreadIndex,
                     /*thread_tag=*/thread_tag);
    String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
                       thread_tag == "vthread.y" || thread_tag == "vthread.z")
                          ? tir::attr::virtual_thread
                          : tir::attr::thread_extent;
    return AttrStmt(/*node=*/std::move(iter_var),
                    /*attr_key=*/std::move(attr_key),
                    /*value=*/std::move(extent),
                    /*body=*/std::move(body));
  }

  /*! \brief Convert attr value from annotation map into PrimExpr. */
  PrimExpr ConvertAttrValue(const String &key, const Any &obj) {
    if (obj == nullptr) {
      return PrimExpr();
    } else if (auto expr = obj.try_cast<PrimExpr>()) {
      return expr.value();
    } else if (auto str = obj.try_cast<String>()) {
      return std::move(StringImm(str.value()));
    } else {
      LOG(FATAL) << "Illegal attribute of key " << key << ", value type "
                 << obj.GetTypeKey() << " not supported";
      return PrimExpr();
    }
  }

  /*!
   * \brief Helper to handle annotation dict.
   * (1) if the attr key is prefixed by `pragma_`, move to ordered kv list. They
   * are lowered to `AttrStmt` by legacy TE schedule convention.
   * (2) the non-pragma loop annotations are preserved
   * (3) the non-pragma block annotations are dropped
   * \return New annotation dict with preserved keys. Also update pragma attr
   * pairs ordered by key.
   */
  Map<String, ffi::Any>
  HandleAnnotations(const Map<String, ffi::Any> &annotations,
                    std::vector<std::pair<std::string, PrimExpr>> *pragma_attrs,
211
212
                    bool is_block,
                    const Array<Buffer> &alloc_buffers = Array<Buffer>()) {
213
214
215
216
217
218
    Map<String, ffi::Any> preserved_annotations;
    pragma_attrs->clear();
    for (const auto &kv : annotations) {
      const String &key = kv.first;
      if (tir::attr::IsPragmaKey(key)) {
        pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second));
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
      } else if (key == tl::attr::kLocalVarInit) {
        if (auto local_init_map = kv.second.try_cast<Map<Var, PrimExpr>>()) {
          for (const auto &pair : local_init_map.value()) {
            local_var_init_map_.Set(pair.first, pair.second);
          }
        } else if (auto init_expr = kv.second.try_cast<PrimExpr>()) {
          ICHECK(is_block) << "`" << tl::attr::kLocalVarInit
                           << "` on non-block annotations is not supported";
          Buffer target = ResolveLocalVarBuffer(alloc_buffers);
          if (!target.defined()) {
            LOG(WARNING) << "Failed to resolve buffer for `"
                         << tl::attr::kLocalVarInit << "` annotation";
            continue;
          }
          local_var_init_map_.Set(target->data, init_expr.value());
        } else {
          LOG(FATAL) << "Expected `" << tl::attr::kLocalVarInit
                     << "` to be a PrimExpr or Map<Var, PrimExpr>, but got "
                     << kv.second.GetTypeKey();
        }
239
240
241
242
243
244
245
246
247
248
249
      } else if (!is_block) {
        // the loop annotation is preserved
        preserved_annotations.Set(key, kv.second);
      }
    }
    std::sort(
        pragma_attrs->begin(), pragma_attrs->end(),
        [](const auto &p1, const auto &p2) { return p1.first < p2.first; });
    return preserved_annotations;
  }

250
251
252
253
254
255
256
257
258
259
260
261
262
  Buffer ResolveLocalVarBuffer(const Array<Buffer> &alloc_buffers) const {
    for (const Buffer &buffer : alloc_buffers) {
      std::string scope = buffer.scope();
      if (scope.find("local.var") != std::string::npos) {
        return buffer;
      }
    }
    if (!alloc_buffers.empty()) {
      return alloc_buffers.back();
    }
    return Buffer();
  }

263
264
265
266
267
268
269
270
271
  /*! \brief Record the loop_var and loop start value of unit loops, whose
   * extent is one. */
  std::unordered_map<Var, PrimExpr> unit_loop_vars_;

  /*! \brief Attr keys to preserve into loop annotations. */
  std::unordered_set<std::string> preserved_annotations_;

  /*! \brief The map from buffer var to its storage alignment information. */
  std::unordered_map<Var, StorageAlignAnnotation> storage_align_;
272
273
274

  /*! \brief Local var initializers collected from block annotations. */
  Map<Var, PrimExpr> local_var_init_map_;
275
276
277
};

/*!
 * \brief Run OpaqueBlockLower on a single PrimFunc.
 * \param f The function to lower.
 * \return The result of OpaqueBlockLower::Rewrite applied to `f`.
 */
PrimFunc TLLowerOpaqueBlock(PrimFunc f) {
  PrimFunc lowered = OpaqueBlockLower::Rewrite(std::move(f));
  return lowered;
}

/*!
 * \brief Create the PrimFunc pass that applies TLLowerOpaqueBlock.
 * \return A pass created with opt level 0, the name "tl.LowerOpaqueBlock"
 *         and an empty list of required passes.
 */
tir::transform::Pass LowerOpaqueBlock() {
  using namespace tir::transform;
  // The callback is stateless and ignores the module and the pass context,
  // so nothing is captured and the unused parameters are left unnamed.
  auto pass_func = [](PrimFunc f, const IRModule & /*m*/,
                      const PassContext & /*ctx*/) {
    return TLLowerOpaqueBlock(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
}

289
TVM_FFI_STATIC_INIT_BLOCK() {
290
291
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
292
}
293
294
295

} // namespace tl
} // namespace tvm