/*!
 * \file lower_tile_op.cc
 * \brief Lower the tile op for further codegen.
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "../op/op.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"

namespace tvm {
namespace tl {

using namespace tir;

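/*!
 * \brief Create a buffer whose shape follows the given layout's output shape,
 *        lowering "local.fragment" storage to "local" and caching data-var
 *        replacements in var_remap.
 */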
static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
                                   Map<Var, Var> &var_remap) {
  const auto *ptr_type =
      TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  Type new_type;
  // Convert fragment buffers to plain local buffers.
  if (ptr_type->storage_scope == "local.fragment") {
    new_type = PointerType(ptr_type->element_type, "local");
  } else {
    new_type = buffer->data->type_annotation;
  }
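  // Global buffers keep their original data var; all other scopes get a
  // fresh var (cached in var_remap) that carries the updated pointer type.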
  Var new_var;
  if (ptr_type->storage_scope == "global") {
    new_var = buffer->data;
  } else {
    if (var_remap.count(buffer->data)) {
      new_var = var_remap[buffer->data];
    } else {
      new_var = Var(buffer->data->name_hint, new_type);
      var_remap.Set(buffer->data, new_var);
    }
  }
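
  // The new buffer shape follows the layout's output shape. For shared
  // buffers the layout may cover only a single replica of the data; any
  // remaining extent is prepended below as a leading replication dimension.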
  Array<PrimExpr> layout_shape = layout->OutputShape();
  Array<PrimExpr> output_shape = layout_shape;

  if (ptr_type->storage_scope == "shared" ||
      ptr_type->storage_scope == "shared.dyn") {
    int replicate_extent = 1;
    Array<PrimExpr> buffer_shape = buffer->shape;
    int buffer_extent = 1;
    int layout_extent = 1;
    for (size_t i = 0; i < buffer_shape.size(); i++) {
      const auto *shape = buffer_shape[i].as<IntImmNode>();
      ICHECK(shape) << "Shared buffer shape must be statically known, got "
                    << buffer_shape[i];
      buffer_extent *= shape->value;
    }
    for (size_t i = 0; i < layout_shape.size(); i++) {
      const auto *shape = layout_shape[i].as<IntImmNode>();
      ICHECK(shape) << "Layout output shape must be statically known, got "
                    << layout_shape[i];
      layout_extent *= shape->value;
    }
    replicate_extent = buffer_extent / layout_extent;
    if (replicate_extent > 1) {
      output_shape.insert(output_shape.begin(), replicate_extent);
    }
  }
  return Buffer(new_var, buffer->dtype, output_shape, {}, buffer->elem_offset,
                buffer->name, buffer->data_alignment, buffer->offset_factor,
                buffer->buffer_type);
}

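/*!
 * \brief Mutator that lowers TileLang tile operators into plain TIR,
 *        remapping buffers to their inferred layouts along the way.
 */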
class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    LowerTileOpPass substituter(&analyzer);
    // Trace the buffer map for tvm_access_ptr
    substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
    substituter.target_ = target.value();
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  Stmt VisitStmt_(const BlockNode *op) final {
    // Record the mapping from buffer data var to buffer for later lookup.
    for (auto buffer : op->alloc_buffers) {
      buffer_map_.insert({buffer->data, buffer});
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    for (auto match_buffer : op->match_buffers) {
      buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
    }
    Map<Var, Layout> vmap;
    if (op->annotations.count(attr::kLayoutMap)) {
      auto layout_map = op->annotations.at(attr::kLayoutMap)
                            .as<Map<Buffer, Layout>>()
                            .value();
      for (auto [buffer, layout] : layout_map) {
        buffer_remap_.Set(buffer,
                          makeBufferWithLayout(buffer, layout, var_remap_));
        layout_map_.Set(buffer, layout);
      }
    }
    auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
    auto block_ptr = block.CopyOnWrite();
    for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
      auto buffer = block->alloc_buffers[i];
      if (buffer_remap_.count(buffer)) {
        block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
      }
    }
    for (const auto &buffer : workspaces_)
      block_ptr->alloc_buffers.push_back(buffer);
    workspaces_.clear();
    block_ptr->annotations.erase(attr::kLayoutMap);
    return block;
  }

  int CheckAndGetBufferRowSize(Buffer buffer) {
    CHECK_GE(buffer->shape.size(), 2)
        << "The dimension of Buffer \"" << buffer->name << "\" with shape "
        << buffer->shape << " should be at least 2";

    auto dim = buffer->shape.size();
    const auto *row_size = buffer->shape[dim - 1].as<IntImmNode>();
    ICHECK(row_size) << "The last dimension of Buffer \"" << buffer->name
                     << "\" must be a constant integer";
    return static_cast<int>(row_size->value);
  }

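  // Rewrite an address_of()-based access pointer whose buffer has a remapped
  // layout: linearize the indices, fold in the extra offset, push the flat
  // offset through the layout's Forward map, and rebuild the pointer.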
  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
                                    Optional<PrimExpr> offset = NullOpt,
                                    DataType dtype = DataType::Int(32)) {
    // The third argument (index 2) of a T.tvm_access_ptr call is the offset;
    // we zero it out and fold the value into smem_offset.
    CHECK(access_ptr->IsInstance<CallNode>())
        << "Invalid access ptr for permuted layout: " << access_ptr;
    auto access_ptr_call = Downcast<Call>(access_ptr);
    if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
      LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
    } else if (access_ptr_call->op.same_as(builtin::address_of())) {
      BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
      Array<PrimExpr> indices = load->indices;
      Array<PrimExpr> shape = load->buffer->shape;

      CHECK_EQ(indices.size(), shape.size())
          << "Indices size and shape size must match for a general "
             "N-dimensional buffer, but got indices size "
          << indices.size() << " and shape size " << shape.size();

      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        elem_offset += indices[i] * stride;
        stride *= shape[i];
      }

      PrimExpr smem_offset =
          elem_offset + (offset.defined() ? offset.value() : 0);

      auto new_buffer = buffer_remap_[load->buffer];

      auto buffer_map_iter =
          buffer_map_.find(Downcast<Var>(load->buffer->data));
      CHECK(buffer_map_iter != buffer_map_.end())
          << "The buffer corresponding to data Var " << access_ptr_call->args[0]
          << " is not found";

      // Validate the buffer shape; the row size itself is not needed here.
      (void)CheckAndGetBufferRowSize(buffer_map_iter->second);

      // Convert the flat offset to multi-dimensional indices, reindex them
      // through the layout, then flatten back to a linear offset.
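      // E.g. (illustrative): with shape [64, 32], a flat offset of 70 becomes
      // indices [2, 6], which the layout's Forward map then permutes.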
      Array<PrimExpr> multi_dim_indices;
      PrimExpr remaining_offset = smem_offset;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        multi_dim_indices.insert(multi_dim_indices.begin(),
                                 floormod(remaining_offset, shape[i]));
        remaining_offset = floordiv(remaining_offset, shape[i]);
      }

      auto forward_indices =
          layout_map_[load->buffer]->Forward(multi_dim_indices);
      PrimExpr new_offset = 0;
      PrimExpr stride_offset = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_offset += forward_indices[i] * stride_offset;
        stride_offset *= shape[i];
      }
      new_offset = analyzer_->Simplify(new_offset);

      Array<PrimExpr> new_indices;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i]));
        new_offset = floordiv(new_offset, shape[i]);
      }

      auto new_access_ptr = access_ptr_call.CopyOnWrite();
      new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
    } else {
      LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
    }

    return access_ptr_call;
  }

  PrimExpr VisitExpr_(const tir::CallNode *op) final {
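    // Only ptx_ldmatrix and mma_store need their access pointers rewritten;
    // every other call is visited normally.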
    Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
                                         builtin::mma_store()};

    if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
        ptx_instructions.end()) {
      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
      return call;
    } else {
      is_ptx_ = true;
    }
    // Rewrite accesses that move data between shared / shared.dyn and local.
    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (call->op.same_as(builtin::ptx_ldmatrix())) {
      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
      // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
      // or T.address_of(buffer[indices])
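      // Illustrative example: T.ptx_ldmatrix(..., T.address_of(A_sh[2, 6]), 16)
      // is rewritten so that the offset 16 is folded into the layout-mapped
      // indices and the trailing smem_offset argument becomes 0.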
      auto access_ptr = call->args[5];
      PrimExpr smem_offset = call->args[6];
      Call address_of_call = Downcast<Call>(access_ptr);
      if (!address_of_call->op.same_as(builtin::address_of())) {
        LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
      }
      BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);

      if (buffer_remap_.count(load->buffer)) {
        auto new_access_ptr =
            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
        auto new_call = call.CopyOnWrite();
        new_call->args.Set(5, new_access_ptr);
        new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
      }
    } else if (call->op.same_as(builtin::mma_store())) {
      // The result is stored directly to the destination buffer rather than
      // through an mma_store intrinsic call.
      auto access_ptr = call->args[2];
      auto new_access_ptr =
          HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
      auto new_call = call.CopyOnWrite();
      new_call->args.Set(2, new_access_ptr);
    } else {
      LOG(FATAL) << "Invalid call node: " << call;
    }
    is_ptx_ = false;
    return call;
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (is_ptx_) {
      return load;
    }
    auto buffer = load->buffer;
    if (buffer_remap_.count(buffer)) {
      auto new_indices = layout_map_[buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];
      return BufferLoad(new_buffer, new_indices);
    } else if (var_remap_.count(buffer->data)) {
      auto new_buffer = Buffer(
          var_remap_[buffer->data], buffer->dtype, buffer->shape,
          buffer->strides, buffer->elem_offset, buffer->name,
          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
      return BufferLoad(new_buffer, load->indices);
    }
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto buffer = store->buffer;
    if (buffer_remap_.count(buffer)) {
      auto new_indices = layout_map_[buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    } else if (var_remap_.count(buffer->data)) {
      auto new_buffer = Buffer(
          var_remap_[buffer->data], buffer->dtype, buffer->shape,
          buffer->strides, buffer->elem_offset, buffer->name,
          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
      return BufferStore(new_buffer, store->value, store->indices);
    }
    return store;
  }

  PrimExpr VisitExpr_(const VarNode *op) final {
    auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (buffer_data_to_buffer_.count(var)) {
      auto buffer = buffer_data_to_buffer_[var];
      if (buffer_remap_.count(buffer))
        return buffer_remap_[buffer]->data;
    }
    return var;
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = op->value.as<CallNode>();
    // Do not analyze calls to global functions.
    if (call && call->op.as<GlobalVarNode>())
      return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));

    auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
    if (tile_op == nullptr)
      return IRMutatorWithAnalyzer::VisitStmt_(op);
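
    // Tile operators may request scratch space while being lowered; each
    // request is materialized as a shared.dyn allocation on the enclosing
    // block (see VisitStmt_(const BlockNode*)).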
    AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
      auto workspace =
          decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
      workspaces_.push_back(workspace);
      return workspace.access_ptr(2); // write
    };

    // Get pass config `tl.disable_tma_lower`
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_tma_lower =
        ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
    bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
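
    // Derive the [min, min + extent) range of threadIdx.x from the analyzer
    // when it is bound; otherwise fall back to a single-thread range.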
    Range thread_bounds;

    if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
      auto const_int_bound = analyzer_->const_int_bound(thread_var_->var);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
      auto extent = max_value + 1 - min_value;
      thread_bounds =
          Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
                               IntImm(thread_var_->var.dtype(), extent));
    } else {
      thread_bounds = Range::FromMinExtent(0, 1);
    }

    auto lowered = tile_op->Lower(
        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
                  layout_map_, buffer_remap_, disable_tma_lower},
        analyzer_);
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
        const auto *extent = iv->dom->extent.as<IntImmNode>();
        ICHECK(extent) << "thread extent must be a constant integer";
        thread_block_size_ = extent->value;
      }
    }
    return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  Target target_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Layout> layout_map_;
  Map<Buffer, Buffer> buffer_remap_;
  // Workaround for the CPU backend: serial loops still need a thread_var,
  // so a trivial single-thread IterVar is used by default.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  size_t thread_block_size_ = 0;
  Array<Buffer> workspaces_;
  // For PTX nodes, buffers and indices are remapped at the CallNode level
  // rather than at the BufferLoad level.
  bool is_ptx_{false};
  // Mapping from data Var of a Buffer to Buffer, for lookup
  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
  Map<Var, Var> var_remap_;
};

namespace transform {

using namespace tir::transform;

tvm::transform::Pass LowerTileOp() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return LowerTileOpPass::Substitute(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
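
// Minimal usage sketch (assumptions: a standard TVM pass pipeline and a
// PrimFunc that already carries a `target` attribute):
//   tvm::transform::Sequential seq({tvm::tl::transform::LowerTileOp()});
//   mod = seq(mod);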

TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
} // namespace transform

} // namespace tl
} // namespace tvm