lower_hopper_intrin.cc 5.93 KB
Newer Older
1
2
3
4
5
/*!
 * \file lower_hopper_intrin.cc
 * \brief Lower Hopper (sm90+) CUDA intrinsics: TMA descriptor creation and
 *        mbarrier initialization.
 */

6
#include <tvm/ffi/reflection/registry.h>
7
8
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
9
#include <tvm/tir/op.h>
10
11
12
13
14
15
16
17
18
19
20
21
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"

namespace tvm {
namespace tl {

using namespace tir;

22
#if (CUDA_MAJOR_VERSION >= 12)
23
class LowerHopperIntrin : public StmtExprMutator {
24
public:
25
  static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) {
26
    PrimFuncNode *fptr = f.CopyOnWrite();
27
    LowerHopperIntrin substituter(disable_shuffle_elect);
28
    fptr->body = substituter.VisitStmt(f->body);
29
    Map<String, Array<PrimExpr>> init_desc_arg_map;
30
31
    for (auto [call, var] : substituter.desc_map_) {
      // Should allocate 128 bytes for TensorMap on stack
32
33
      Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
                             {StringImm("arg_value"), 16});
34
      Array<PrimExpr> init_desc_args;
35
      if (call->op.same_as(create_tma_descriptor())) {
36
        init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
37
      } else if (call->op.same_as(create_tma_im2col_descriptor())) {
38
39
40
41
42
        init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col));
      } else {
        CHECK(0) << call->op;
      }
      init_desc_args.push_back(var);
43
44
      init_desc_args.insert(init_desc_args.end(), call->args.begin(),
                            call->args.end());
45
      // add to function attribute
46
47
48
49
      Call init_desc =
          Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
      fptr->body =
          LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
50
      init_desc_arg_map.Set(var->name_hint, init_desc_args);
51
    }
52
    f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
53
54
55
    return f;
  }

56
57
58
  Stmt VisitStmt_(const AttrStmtNode *op) final {
    // Insert the prefetch TMA descriptor statement TO the beginning of the
    // kernel
59
60
61
62
63
64
65
66
67
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        auto body = StmtExprMutator::VisitStmt(op->body);
        if (prefetch_calls_.empty() && init_mbarrier_calls_.empty()) {
          return AttrStmt(op->node, op->attr_key, op->value, body);
        } else {
          Array<Stmt> stmt_seq;
          if (!init_mbarrier_calls_.empty()) {
68
69
70
            auto alloc_mbarrier =
                Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
                              {static_cast<int>(init_mbarrier_calls_.size())}));
71
72
73
74
            stmt_seq.push_back(alloc_mbarrier);
          }

          auto stmts = prefetch_calls_;
75
76
          stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
                       init_mbarrier_calls_.end());
77
78
79
80
81
82
83
84
85
          PrimExpr condition;
          if (!disable_shuffle_elect_) {
            condition = Call(DataType::Bool(), tl_shuffle_elect(), {0});
          } else {
            condition = EQ(iv->var, 0);
          }
          auto stmt_ = IfThenElse(condition,
                                  stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
          stmt_seq.push_back(stmt_);
86
          if (!init_mbarrier_calls_.empty()) {
87
88
89
            Stmt mem_sync =
                Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
                              {StringImm("shared")}));
90
91
92
93
94
95
96
97
98
99
100
101
102
            stmt_seq.push_back(mem_sync);
          }
          stmt_seq.push_back(body);

          prefetch_calls_.clear();
          init_mbarrier_calls_.clear();
          return AttrStmt(op->node, op->attr_key, op->value, SeqStmt(stmt_seq));
        }
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

103
  PrimExpr VisitExpr_(const CallNode *call) final {
104
105
    if (call->op.same_as(create_tma_descriptor()) ||
        call->op.same_as(create_tma_im2col_descriptor())) {
106
107
108
109
110
111
      Var var;
      auto iter = desc_map_.find(GetRef<Call>(call));
      if (iter != desc_map_.end()) {
        var = iter->second;
      } else {
        String name = call->args[2].as<Var>().value()->name_hint;
112
113
        var = Var(name + "_desc",
                  PointerType(PrimType(cuTensorMapType()), "grid_constant"));
114
        desc_map_[GetRef<Call>(call)] = var;
115
116
117
        prefetch_calls_.push_back(
            Evaluate(Call(DataType::Handle(), builtin::call_extern(),
                          {StringImm("tl::prefetch_tma_descriptor"), var})));
118
119
      }
      return var;
120
    } else if (call->op.same_as(create_list_of_mbarrier())) {
121
122
123
      ICHECK(init_mbarrier_calls_.size() == 0);
      int num_barriers = static_cast<int>(call->args.size());
      for (int i = 0; i < num_barriers; i++) {
124
        PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i});
125
126
127
        init_mbarrier_calls_.push_back(Evaluate(
            Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
                 {mbarrier, call->args[i]})));
128
129
130
131
132
133
134
      }
      return 0;
    } else {
      return StmtExprMutator::VisitExpr_(call);
    }
  }

135
private:
136
137
138
  Array<Stmt> prefetch_calls_;
  Array<Stmt> init_mbarrier_calls_;
  std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
139
140
141
  LowerHopperIntrin(bool disable_shuffle_elect)
      : disable_shuffle_elect_(disable_shuffle_elect) {}
  bool disable_shuffle_elect_;
142
143
144
145
146
147
};

using namespace tir::transform;

/*!
 * \brief Build the "tl.LowerHopperIntrin" PrimFunc pass (opt level 0).
 *
 * For every PrimFunc, the pass reads kDisableShuffleElect from the
 * PassContext (default false) and forwards it to
 * LowerHopperIntrin::Substitute.
 */
tvm::transform::Pass LowerHopperIntrin() {
  auto rewrite = [](PrimFunc func, IRModule /*mod*/, PassContext pass_ctx) {
    const bool use_thread_index_guard =
        pass_ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
    return LowerHopperIntrin::Substitute(func, use_thread_index_guard);
  };
  return CreatePrimFuncPass(rewrite, /*opt_level=*/0, "tl.LowerHopperIntrin",
                            {});
}

155
156
157
158
// Register the pass factory in the FFI global function table under the name
// "tl.transform.LowerHopperIntrin".
TVM_FFI_STATIC_INIT_BLOCK({
  tvm::ffi::reflection::GlobalDef().def("tl.transform.LowerHopperIntrin",
                                        LowerHopperIntrin);
});
159
#endif // (CUDA_MAJOR_VERSION >= 12)
160

161
162
} // namespace tl
} // namespace tvm