/*!
 * \file lower_tile_op.cc
 * \brief Lower the tile op for further codegen.
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "../op/op.h"

#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Build a replacement Buffer whose shape is the layout's output shape.
 *
 * The new buffer reuses (or creates) a remapped data Var:
 *  - "global" buffers keep their original data Var;
 *  - "local.fragment" buffers are converted to plain "local" storage;
 *  - all other non-global buffers get a fresh Var, cached in \p var_remap so
 *    buffers sharing a data Var map to the same replacement.
 *
 * For "shared"/"shared.dyn" buffers, if the buffer's total constant extent is
 * a multiple (>1) of the layout's total extent, the replication factor is
 * prepended as an extra outermost dimension.
 *
 * \param buffer    The original buffer.
 * \param layout    The layout whose OutputShape() defines the new shape.
 * \param var_remap In/out cache mapping original data Vars to replacements.
 * \return The rebuilt buffer (strides intentionally dropped, i.e. compact).
 */
static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
                                   Map<Var, Var> &var_remap) {
  const auto *ptr_type =
      TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  Type new_type;
  // convert fragments to normal local buffer
  if (ptr_type->storage_scope == "local.fragment") {
    new_type = PointerType(ptr_type->element_type, "local");
  } else {
    new_type = buffer->data->type_annotation;
  }
  Var new_var;
  if (ptr_type->storage_scope == "global") {
    // Global buffers are left on their original data Var.
    new_var = buffer->data;
  } else {
    if (var_remap.count(buffer->data)) {
      new_var = var_remap[buffer->data];
    } else {
      new_var = Var(buffer->data->name_hint, new_type);
      var_remap.Set(buffer->data, new_var);
    }
  }
  Array<PrimExpr> layout_shape = layout->OutputShape();
  Array<PrimExpr> output_shape = layout_shape;

  if (ptr_type->storage_scope == "shared" ||
      ptr_type->storage_scope == "shared.dyn") {
    Array<PrimExpr> buffer_shape = buffer->shape;
    // Use 64-bit accumulators: products of dimension extents can overflow a
    // 32-bit int for large shared buffers.
    int64_t buffer_extent = 1;
    int64_t layout_extent = 1;
    for (size_t i = 0; i < buffer_shape.size(); i++) {
      const auto *dim = buffer_shape[i].as<IntImmNode>();
      // Guard the cast: a non-constant dimension would otherwise crash on a
      // null dereference.
      ICHECK(dim) << "makeBufferWithLayout: buffer \"" << buffer->name
                  << "\" has non-constant dimension " << buffer_shape[i];
      buffer_extent *= dim->value;
    }
    for (size_t i = 0; i < layout_shape.size(); i++) {
      const auto *dim = layout_shape[i].as<IntImmNode>();
      ICHECK(dim) << "makeBufferWithLayout: layout for buffer \""
                  << buffer->name << "\" has non-constant dimension "
                  << layout_shape[i];
      layout_extent *= dim->value;
    }
    ICHECK_GT(layout_extent, 0)
        << "makeBufferWithLayout: layout for buffer \"" << buffer->name
        << "\" has zero extent";
    int64_t replicate_extent = buffer_extent / layout_extent;
    if (replicate_extent > 1) {
      // Prepend the replication factor as the outermost dimension.
      output_shape.insert(output_shape.begin(),
                          IntImm(DataType::Int(32), replicate_extent));
    }
  }
  // Strides are deliberately empty: the remapped buffer is compact.
  return Buffer(new_var, buffer->dtype, output_shape, {}, buffer->elem_offset,
                buffer->name, buffer->data_alignment, buffer->offset_factor,
                buffer->buffer_type);
}
/*!
 * \brief Rewrites buffer references in a statement based on a given buffer
 * remapping.
 *
 * After buffers have been replaced, block annotations that still refer to the
 * old buffer data Vars (currently the padding map) must be re-keyed onto the
 * replacement Vars; this mutator performs that fix-up.
 */
class RemapBufferRewriter : public arith::IRMutatorWithAnalyzer {
public:
  /*!
   * \brief Substitute buffer references in a statement based on a given buffer
   * remapping. \param stmt The statement to rewrite. \param buffer_remap A map
   * from old buffers to new buffers. \return The rewritten statement.
   */
  static Stmt Substitute(Stmt stmt, Map<Buffer, Buffer> buffer_remap) {
    arith::Analyzer analyzer;
    RemapBufferRewriter rewriter(&analyzer);
    rewriter.buffer_remap_ = std::move(buffer_remap);
    return rewriter.VisitStmt(stmt);
  }

private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  Stmt VisitStmt_(const BlockNode *op) final {
    // Only blocks carrying a padding-map annotation need special handling.
    return op->annotations.count(attr::kPaddingMap)
               ? RewritePaddingMap(op)
               : IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  /*!
   * \brief Re-key the padding-map annotation of a block through the data-Var
   * remapping induced by buffer_remap_.
   * \param op The block node to rewrite.
   * \return The rewritten block.
   */
  Stmt RewritePaddingMap(const BlockNode *op) {
    // Data-Var replacement implied by the buffer remapping.
    Map<Var, Var> data_var_remap;
    for (const auto &[old_buffer, new_buffer] : buffer_remap_) {
      data_var_remap.Set(old_buffer->data, new_buffer->data);
    }

    auto old_padding_map =
        op->annotations.Get(attr::kPaddingMap).as<Map<Var, PrimExpr>>().value();
    // Keys with a replacement are re-pointed; everything else is kept as-is.
    Map<Var, PrimExpr> new_padding_map;
    for (const auto &[data_var, padding] : old_padding_map) {
      if (data_var_remap.count(data_var)) {
        new_padding_map.Set(data_var_remap.at(data_var), padding);
      } else {
        new_padding_map.Set(data_var, padding);
      }
    }

    auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto block_node = block.CopyOnWrite();
    block_node->annotations.Set(attr::kPaddingMap, new_padding_map);
    return block;
  }

  Map<Buffer, Buffer> buffer_remap_;
};
/*!
 * \brief Mutator that lowers TileLang tile operators for further codegen.
 *
 * Responsibilities visible in this pass:
 *  - replace buffers listed in a block's attr::kLayoutMap annotation with
 *    layout-transformed buffers (via makeBufferWithLayout);
 *  - re-index BufferLoad/BufferStore and ptx_ldmatrix/mma_store access
 *    pointers through the recorded Layout::Forward mapping;
 *  - lower tile operators found in Evaluate nodes via their Lower() method,
 *    appending shared workspaces to the enclosing block on demand;
 *  - track whether any TMA intrinsic appears so Substitute() can adjust the
 *    kDisableTMALower pass config afterwards.
 */
class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
  /*!
   * \brief Run the pass over a PrimFunc.
   * \param f The function to lower; must carry a tvm::attr::kTarget attribute.
   * \return The lowered function.
   */
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    LowerTileOpPass substituter(&analyzer);
    // Trace the buffer map for tvm_access_ptr
    substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
    substituter.target_ = target.value();
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    // After lowering, fix up annotations (padding map) that still reference
    // the pre-remap buffer data Vars.
    fptr->body =
        RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_);
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_tma_lower =
        ctxt->GetConfig(kDisableTMALower, Optional<Bool>());

    if (!opt_disable_tma_lower.value_or(Bool(false))) {
      // @lei: this is a workaround, as if we don't disable tma lower,
      // cp async lowering won't be generated.
      ctxt->config.Set(kDisableTMALower, Bool(!substituter.has_tma_));
    }
    return f;
  }

private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  /*!
   * \brief Visit a block: register its buffers, record the layout map, then
   * swap remapped allocations in and append any workspaces requested while
   * lowering the block's body.
   */
  Stmt VisitStmt_(const BlockNode *op) final {
    // Record the mapping from buffer data var to buffer for later lookup
    for (auto buffer : op->alloc_buffers) {
      buffer_map_.insert({buffer->data, buffer});
    }
    for (auto match_buffer : op->match_buffers) {
      buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
    }
    for (auto buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // NOTE(review): vmap appears unused in this method; kept as-is.
    Map<Var, Layout> vmap;
    if (op->annotations.count(attr::kLayoutMap)) {
      auto layout_map = op->annotations.at(attr::kLayoutMap)
                            .as<Map<Buffer, Layout>>()
                            .value();
      // Build the layout-transformed replacement for every annotated buffer.
      for (auto [buffer, layout] : layout_map) {
        buffer_remap_.Set(buffer,
                          makeBufferWithLayout(buffer, layout, var_remap_));
        layout_map_.Set(buffer, layout);
      }
    }
    // Visit children first so loads/stores/calls are rewritten before the
    // allocation list is patched below.
    auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
    auto block_ptr = block.CopyOnWrite();
    for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
      auto buffer = block->alloc_buffers[i];
      if (buffer_remap_.count(buffer)) {
        block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
      }
    }
    // Workspaces requested by tile-op lowering (see the Evaluate visitor) are
    // allocated in this enclosing block, then the list is reset.
    for (const auto &buffer : workspaces_)
      block_ptr->alloc_buffers.push_back(buffer);
    workspaces_.clear();
    // The layout map has been consumed; drop the annotation.
    block_ptr->annotations.erase(attr::kLayoutMap);
    return block;
  }

  /*!
   * \brief Validate that \p buffer is at least 2-D and return its innermost
   * (row) extent. CHECK-fails otherwise.
   */
  int CheckAndGetBufferRowSize(Buffer buffer) {
    CHECK(buffer->shape.size() >= 2)
        << "The dimension of Buffer \"" << buffer->name << "\" with shape "
        << buffer->shape << " should be at least 2";

    auto dim = buffer->shape.size();
    auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
    return buffer_row_size;
  }

  /*!
   * \brief Rewrite an address_of(...) access pointer so it indexes the
   * layout-remapped buffer.
   *
   * Flattens the load indices (plus optional extra \p offset) to a linear
   * element offset, unflattens to multi-dimensional indices, pushes them
   * through the buffer's Layout::Forward, and re-linearizes/unflattens back.
   *
   * \param access_ptr A Call to builtin::address_of(); tvm_access_ptr is not
   *        supported yet (LOG(FATAL)).
   * \param offset Optional extra element offset folded into the index.
   * \param dtype NOTE(review): currently unused by the body.
   * \return The rewritten access-pointer call.
   */
  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
                                    Optional<PrimExpr> offset = NullOpt,
                                    DataType dtype = DataType::Int(32)) {
    // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
    // accumulate it to smem_offset
    CHECK(access_ptr->IsInstance<CallNode>())
        << "Invalid access ptr for permuted layout: " << access_ptr;
    auto access_ptr_call = Downcast<Call>(access_ptr);
    if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
      LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
    } else if (access_ptr_call->op.same_as(builtin::address_of())) {
      BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
      Array<PrimExpr> indices = load->indices;
      Array<PrimExpr> shape = load->buffer->shape;

      CHECK_EQ(indices.size(), shape.size())
          << "Indices size and shape size must match for general N-dimensional "
             "buffer "
          << "but got indices size: " << indices.size()
          << " and shape size: " << shape.size();

      // Row-major linearization of the load indices.
      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        elem_offset += indices[i] * stride;
        stride *= shape[i];
      }

      PrimExpr smem_offset =
          elem_offset + (offset.defined() ? offset.value() : 0);

      // NOTE(review): assumes load->buffer is present in buffer_remap_
      // (callers on the ptx_ldmatrix path check this before calling).
      auto new_buffer = buffer_remap_[load->buffer];

      auto buffer_map_iter =
          buffer_map_.find(Downcast<Var>(load->buffer->data));
      CHECK(buffer_map_iter != buffer_map_.end())
          << "The buffer corresponding to data Var " << access_ptr_call->args[0]
          << " is not found";

      // Value unused; the call is kept for its side effect of CHECKing that
      // the buffer is at least 2-D.
      int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
      (void)buffer_row_size;

      // Convert offset to target-dimension, reindex it and convert it back
      Array<PrimExpr> multi_dim_indices;
      PrimExpr remaining_offset = smem_offset;

      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        multi_dim_indices.insert(multi_dim_indices.begin(),
                                 floormod(remaining_offset, shape[i]));
        remaining_offset = floordiv(remaining_offset, shape[i]);
      }

      // Apply the layout transformation to the multi-dimensional index.
      auto forward_indices =
          layout_map_[load->buffer]->Forward(multi_dim_indices);
      // Re-linearize the transformed index...
      PrimExpr new_offset = 0;
      PrimExpr stride_offset = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_offset += forward_indices[i] * stride_offset;
        stride_offset *= shape[i];
      }
      new_offset = analyzer_->Simplify(new_offset);

      // ...and unflatten it against the original shape for the new load.
      Array<PrimExpr> new_indices;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        new_indices.insert(new_indices.begin(), floormod(new_offset, shape[i]));
        new_offset = floordiv(new_offset, shape[i]);
      }

      auto new_access_ptr = access_ptr_call.CopyOnWrite();
      new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
    } else {
      LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
    }

    return access_ptr_call;
  }

  /*!
   * \brief Visit calls: note TMA intrinsics, and rewrite the access pointers
   * of ptx_ldmatrix / mma_store through the layout remapping.
   *
   * While a ptx instruction's operands are being visited, is_ptx_ is set so
   * the BufferLoad visitor leaves its loads untouched (they are rewritten
   * here at the call level instead).
   */
  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    if ((!has_tma_) && (op->op.same_as(tl::tma_load()) ||
                        op->op.same_as(tl::tma_load_im2col()) ||
                        op->op.same_as(tl::tma_store()))) {
      has_tma_ = true;
    }
    Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
                                         builtin::mma_store()};

    if (std::find(ptx_instructions.begin(), ptx_instructions.end(), op->op) ==
        ptx_instructions.end()) {
      // Not a ptx instruction we rewrite: default traversal.
      auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
      return call;
    } else {
      is_ptx_ = true;
    }
    // Rewrite from/to shared or shared.dyn to/from local
    auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (call->op.same_as(builtin::ptx_ldmatrix())) {
      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
      // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
      // or T.address_of(buffer, offset)
      auto access_ptr = call->args[5];
      PrimExpr smem_offset = call->args[6];
      Call address_of_call = Downcast<Call>(access_ptr);
      if (!address_of_call->op.same_as(builtin::address_of())) {
        LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
      }
      BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);

      if (buffer_remap_.count(load->buffer)) {
        // Fold smem_offset into the rewritten pointer and zero the arg.
        auto new_access_ptr =
            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
        auto new_call = call.CopyOnWrite();
        new_call->args.Set(5, new_access_ptr);
        new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
      }
    } else if (call->op.same_as(builtin::mma_store())) {
      // because we will directly store result to Buffer instead of calling
      // mma_store now
      auto access_ptr = call->args[2];
      auto new_access_ptr =
          HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
      auto new_call = call.CopyOnWrite();
      new_call->args.Set(2, new_access_ptr);
    } else {
      LOG(FATAL) << "Invalid call node: " << call;
    }
    is_ptx_ = false;
    return call;
  }

  /*!
   * \brief Rewrite loads on remapped buffers via Layout::Forward, or rebind
   * the buffer onto a remapped data Var. Skipped inside ptx calls.
   */
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (is_ptx_) {
      // Handled at the CallNode level for ptx instructions.
      return load;
    }
    auto buffer = load->buffer;
    if (buffer_remap_.count(buffer)) {
      auto new_indices = layout_map_[buffer]->Forward(load->indices);
      auto new_buffer = buffer_remap_[load->buffer];
      return BufferLoad(new_buffer, new_indices);
    } else if (var_remap_.count(buffer->data)) {
      // Same buffer layout, but the underlying data Var was replaced.
      auto new_buffer = Buffer(
          var_remap_[buffer->data], buffer->dtype, buffer->shape,
          buffer->strides, buffer->elem_offset, buffer->name,
          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
      return BufferLoad(new_buffer, load->indices);
    }
    return load;
  }

  /*!
   * \brief Mirror of the BufferLoad visitor for stores.
   */
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto buffer = store->buffer;
    if (buffer_remap_.count(buffer)) {
      auto new_indices = layout_map_[buffer]->Forward(store->indices);
      auto new_buffer = buffer_remap_[store->buffer];
      return BufferStore(new_buffer, store->value, new_indices);
    } else if (var_remap_.count(buffer->data)) {
      // Same buffer layout, but the underlying data Var was replaced.
      auto new_buffer = Buffer(
          var_remap_[buffer->data], buffer->dtype, buffer->shape,
          buffer->strides, buffer->elem_offset, buffer->name,
          buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
      return BufferStore(new_buffer, store->value, store->indices);
    }
    return store;
  }

  /*!
   * \brief Replace a buffer's data Var with its remapped counterpart.
   */
  PrimExpr VisitExpr_(const VarNode *op) final {
    auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (buffer_data_to_buffer_.count(var)) {
      auto buffer = buffer_data_to_buffer_[var];
      if (buffer_remap_.count(buffer))
        return buffer_remap_[buffer]->data;
    }
    return var;
  }

  /*!
   * \brief Parse tile operators out of Evaluate nodes and lower them.
   *
   * Workspace buffers requested during lowering are collected in workspaces_
   * and materialized by the enclosing Block visitor.
   */
  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = op->value.as<CallNode>();
    // Do not analysis the call node to the global function.
    if (call && call->op.as<GlobalVarNode>())
      return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));

    auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
    if (tile_op == nullptr)
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    // Allocate a shared.dyn workspace on demand and hand back a write pointer.
    AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
      auto workspace =
          decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
      workspaces_.push_back(workspace);
      return workspace.access_ptr(2); // write
    };

    // Get pass config `tl.disable_tma_lower`
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_tma_lower =
        ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
    bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
    Range thread_bounds;

    // Derive the thread bounds from the analyzer's const-int bound on the
    // thread var when available; otherwise fall back to a single thread.
    if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
      auto const_int_bound = analyzer_->const_int_bound(thread_var_);
      auto min_value = const_int_bound->min_value;
      auto max_value = const_int_bound->max_value;
      auto extent = max_value + 1 - min_value;
      thread_bounds =
          Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
                               IntImm(thread_var_->var.dtype(), extent));
    } else {
      thread_bounds = Range::FromMinExtent(0, 1);
    }

    auto lowered = tile_op->Lower(
        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
                  layout_map_, buffer_remap_, disable_tma_lower},
        analyzer_);
    // Recurse into the lowered statement so nested constructs get rewritten.
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }

  /*!
   * \brief Capture the threadIdx.x IterVar and its constant block size.
   */
  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (iv->thread_tag == "threadIdx.x") {
        thread_var_ = iv;
        ICHECK(iv->dom->extent.as<IntImmNode>());
        thread_block_size_ = iv->dom->extent.as<IntImmNode>()->value;
      }
    }
    return arith::IRMutatorWithAnalyzer::VisitStmt_(op);
  }

  // Compilation target taken from the PrimFunc's kTarget attribute.
  Target target_;
  // Data Var -> Buffer for every param and allocation seen so far.
  Map<Var, Buffer> buffer_data_to_buffer_;
  // Buffer -> Layout collected from attr::kLayoutMap annotations.
  Map<Buffer, Layout> layout_map_;
  // Original buffer -> layout-transformed replacement buffer.
  Map<Buffer, Buffer> buffer_remap_;
  // This is a workaround for cpu backend,
  // we need to define a thread_var for the serial loop.
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  // Extent of threadIdx.x (0 until an AttrStmt provides it).
  size_t thread_block_size_ = 0;
  // Workspaces pending allocation in the enclosing block.
  Array<Buffer> workspaces_;
  // For ptx Node, we need to remap the buffer and indices
  // By access CallNode instead of BufferLoad Node.
  bool is_ptx_{false};
  // Mapping from data Var of a Buffer to Buffer, for lookup
  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
  // Original data Var -> replacement Var (filled by makeBufferWithLayout).
  Map<Var, Var> var_remap_;
  // Whether any TMA intrinsic was encountered (drives kDisableTMALower).
  bool has_tma_{false};
};

namespace transform {

using namespace tir::transform;

/*!
 * \brief Create the tl.LowerTileOp pass, which lowers tile operators for
 * further codegen.
 * \return The PrimFunc pass.
 */
tvm::transform::Pass LowerTileOp() {
  // Captureless lambda: the pass body only forwards the function.
  auto rewrite_func = [](PrimFunc func, IRModule mod, PassContext ctx) {
    return LowerTileOpPass::Substitute(std::move(func));
  };
  return CreatePrimFuncPass(rewrite_func, 0, "tl.LowerTileOp", {});
}

// Expose the pass to the FFI as `tl.transform.LowerTileOp`.
TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
} // namespace transform

} // namespace tl
} // namespace tvm