/*!
 * \file lower_tile_op.cc
 * \brief Lower the tile op for further codegen.
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "../op/op.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"

namespace tvm {
namespace tl {

using namespace tir;

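// Create a new Buffer whose shape follows the given layout's output shape.
// Fragment buffers ("local.fragment") are downgraded to plain "local"
// storage; global buffers keep their original data Var (likely so the
// PrimFunc's parameter vars stay intact).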
static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
  const auto *ptr_type =
      TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  Type new_type;
  // Convert fragment buffers to plain local buffers.
  if (ptr_type->storage_scope == "local.fragment") {
    new_type = PointerType(ptr_type->element_type, "local");
  } else {
    new_type = buffer->data->type_annotation;
  }
  Var new_var;
  if (ptr_type->storage_scope == "global") {
    new_var = buffer->data;
  } else {
    new_var = Var(buffer->data->name_hint, new_type);
  }
  return Buffer(new_var, buffer->dtype, layout->OutputShape(), {},
                buffer->elem_offset, buffer->name, buffer->data_alignment,
                buffer->offset_factor, buffer->buffer_type);
}

class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
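  // Entry point: seed the buffer maps from the PrimFunc's buffer_map, pick up
  // the required target attribute, then rewrite the function body.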
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    LowerTileOpPass substituter(&analyzer);
    // Trace the buffer map for tvm_access_ptr
    substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
    substituter.target_ = target.value();
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

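  // Rewrites a Block: records its allocated and matched buffers, applies the
  // layout-map annotation to remap buffers, appends any workspaces requested
  // by lowered tile ops, and drops the consumed layout-map annotation.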
  Stmt VisitStmt_(const BlockNode *op) final {
    // Record the mapping from buffer data var to buffer for later lookup
    for (auto buffer : op->alloc_buffers) {
      buffer_map_.insert({buffer->data, buffer});
    }
    for (auto match_buffer : op->match_buffers) {
      buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
    }
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Map<Var, Layout> vmap;
    if (op->annotations.count(attr::kLayoutMap)) {
      auto layout_map = op->annotations.at(attr::kLayoutMap)
                            .as<Map<Buffer, Layout>>()
                            .value();
      for (auto [buffer, layout] : layout_map) {
        buffer_remap_.Set(buffer, makeBufferWithLayout(buffer, layout));
        layout_map_.Set(buffer, layout);
      }
    }
    auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
    auto block_ptr = block.CopyOnWrite();
    for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
      auto buffer = block->alloc_buffers[i];
      if (buffer_remap_.count(buffer)) {
        block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
      }
    }
    for (const auto &buffer : workspaces_)
      block_ptr->alloc_buffers.push_back(buffer);
    workspaces_.clear();
    block_ptr->annotations.erase(attr::kLayoutMap);
    return block;
  }

  int CheckAndGetBufferRowSize(Buffer buffer) {
    CHECK_GE(buffer->shape.size(), 2)
        << "The dimension of Buffer \"" << buffer->name << "\" with shape "
        << buffer->shape << " should be at least 2";

    auto dim = buffer->shape.size();
    const auto *row_size = buffer->shape[dim - 1].as<IntImmNode>();
    ICHECK(row_size) << "The innermost extent of Buffer \"" << buffer->name
                     << "\" must be a constant integer";
    return static_cast<int>(row_size->value);
  }

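  // Rewrites an address expression (currently only T.address_of) whose buffer
  // has a permuted layout: the flat shared-memory offset is de-linearized
  // into multi-dimensional indices, mapped through the layout's Forward()
  // transform, and linearized back.
  //
  // Illustrative sketch (hypothetical 2D buffer of shape [8, 64] with a
  // swizzled layout L):
  //   flat offset 130 -> indices [2, 2] -> L->Forward([2, 2]) -> re-flatten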
  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
                                    Optional<PrimExpr> offset = NullOpt,
                                    DataType dtype = DataType::Int(32)) {
    // The offset argument (args[2]) of a T.tvm_access_ptr call is reset to 0
    // and accumulated into smem_offset instead.
    CHECK(access_ptr->IsInstance<CallNode>())
        << "Invalid access ptr for permuted layout: " << access_ptr;
    auto access_ptr_call = Downcast<Call>(access_ptr);
    if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
      LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
    } else if (access_ptr_call->op.same_as(builtin::address_of())) {
      BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
      Array<PrimExpr> indices = load->indices;
      Array<PrimExpr> shape = load->buffer->shape;

      CHECK_EQ(indices.size(), shape.size())
          << "Indices size and shape size must match for a general "
             "N-dimensional buffer, but got indices size: " << indices.size()
          << " and shape size: " << shape.size();

      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        elem_offset += indices[i] * stride;
        stride *= shape[i];
      }

      PrimExpr smem_offset =
          elem_offset + (offset.defined() ? offset.value() : 0);

      auto new_buffer = buffer_remap_[load->buffer];

      auto buffer_map_iter =
          buffer_map_.find(Downcast<Var>(load->buffer->data));
      CHECK(buffer_map_iter != buffer_map_.end())
          << "The buffer corresponding to data Var " << load->buffer->data
          << " is not found";

      // The row size is computed only to validate the buffer's shape; it is
      // otherwise unused here.
      int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
      (void)buffer_row_size;

      // De-linearize the flat offset into multi-dimensional indices, remap
      // them through the layout, then linearize back.
      Array<PrimExpr> multi_dim_indices;
      PrimExpr remaining_offset = smem_offset;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        multi_dim_indices.insert(multi_dim_indices.begin(),
                                 floormod(remaining_offset, shape[i]));
        remaining_offset = floordiv(remaining_offset, shape[i]);
      }

      auto forward_indices =
          layout_map_[load->buffer]->Forward(multi_dim_indices);
      PrimExpr new_offset = 0;
      PrimExpr stride_offset = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_offset += forward_indices[i] * stride_offset;
        stride_offset *= shape[i];
      }
      new_offset = analyzer_->Simplify(new_offset);

      Array<PrimExpr> new_indices;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i]));
        new_offset = floordiv(new_offset, shape[i]);
      }

      auto new_access_ptr = access_ptr_call.CopyOnWrite();
      new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
    } else {
      LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
    }

    return access_ptr_call;
  }

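  // Intercepts PTX-level intrinsics (ptx_ldmatrix, mma_store) whose pointer
  // arguments must be rewritten when the underlying buffer has a permuted
  // layout; all other calls are visited normally.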
  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
                                         builtin::mma_store()};

    if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
        ptx_instructions.end()) {
      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
      return call;
    } else {
      is_ptx_ = true;
    }
    // Rewrite pointers that move data between shared (or shared.dyn) and
    // local memory.
    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (call->op.same_as(builtin::ptx_ldmatrix())) {
      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
      // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
      // or T.address_of(buffer, offset)
      auto access_ptr = call->args[5];
      PrimExpr smem_offset = call->args[6];
      Call address_of_call = Downcast<Call>(access_ptr);
      if (!address_of_call->op.same_as(builtin::address_of())) {
        LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
      }
      BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);

      if (buffer_remap_.count(load->buffer)) {
        auto new_access_ptr =
            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
        auto new_call = call.CopyOnWrite();
        new_call->args.Set(5, new_access_ptr);
        new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
      }
    } else if (call->op.same_as(builtin::mma_store())) {
      // We now store the result directly to the buffer instead of calling
      // mma_store, so only the access pointer needs rewriting.
      auto access_ptr = call->args[2];
      auto new_access_ptr =
          HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
      auto new_call = call.CopyOnWrite();
      new_call->args.Set(2, new_access_ptr);
    } else {
      LOG(FATAL) << "Invalid call node: " << call;
    }
    is_ptx_ = false;
    return call;
  }

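  // Remap loads from buffers with a permuted layout. Loads inside PTX
  // intrinsics are left untouched here; they are handled at the CallNode
  // level instead.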
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (is_ptx_) {
      return load;
    }

    if (buffer_remap_.count(load->buffer)) {
      auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];
      return BufferLoad(new_buffer, new_indices);
    }
    return load;
  }

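  // Remap stores to buffers with a permuted layout, mirroring the BufferLoad
  // case above.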
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    }
    return store;
  }

  PrimExpr VisitExpr_(const VarNode *op) final {
    auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (buffer_data_to_buffer_.count(var)) {
      auto buffer = buffer_data_to_buffer_[var];
      if (buffer_remap_.count(buffer))
        return buffer_remap_[buffer]->data;
    }
    return var;
  }

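  // The main tile-op lowering hook: an Evaluate node wrapping a tile-op call
  // is parsed into an operator and lowered to TIR, with a callback that
  // allocates shared-memory workspaces on demand.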
  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = op->value.as<CallNode>();
    // Do not analyze calls to global functions.
    if (call && call->op.as<GlobalVarNode>())
      return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));

    auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
    if (tile_op == nullptr)
      return IRMutatorWithAnalyzer::VisitStmt_(op);
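    // Tile ops may request scratch space during lowering; each request
    // allocates a fresh shared.dyn buffer that is later attached to the
    // enclosing block's alloc_buffers.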
    AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
      auto workspace =
          decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
      workspaces_.push_back(workspace);
      return workspace.access_ptr(2); // write
    };

    // Get pass config `tl.disable_tma_lower`
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_tma_lower =
        ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
    bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
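
    // Derive the bounds of threadIdx.x from the analyzer when they are
    // known; otherwise fall back to a single-thread range [0, 1).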
    Range thread_bounds;

    if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
      auto const_int_bound = analyzer_->const_int_bound(thread_var_);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
      auto extent = max_value + 1 - min_value;
      thread_bounds =
          Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
                               IntImm(thread_var_->var.dtype(), extent));
    } else {
      thread_bounds = Range::FromMinExtent(0, 1);
    }

    auto lowered = tile_op->Lower(
        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
                  layout_map_, buffer_remap_, disable_tma_lower},
        analyzer_);
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_block_size_ = iv->dom->extent.as<IntImmNode>()->value;
      }
    }
    return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  Target target_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Layout> layout_map_;
  Map<Buffer, Buffer> buffer_remap_;
  // Workaround for the CPU backend: define a placeholder thread_var for the
  // serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  size_t thread_block_size_ = 0;
  Array<Buffer> workspaces_;
  // For PTX nodes, the buffer and indices are remapped at the CallNode level
  // rather than at the BufferLoad level.
  bool is_ptx_{false};
  // Mapping from data Var of a Buffer to Buffer, for lookup
  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
};

namespace transform {

using namespace tir::transform;

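// Example invocation (a sketch, assuming the standard TVM pass machinery and
// the Python binding registered below):
//   mod = tl.transform.LowerTileOp()(mod)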
tvm::transform::Pass LowerTileOp() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return LowerTileOpPass::Substitute(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}

TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
} // namespace transform

} // namespace tl
} // namespace tvm