/*!
 * \file lower_hopper_intrin.cc
 * \brief Lower Hopper intrinsics for CUDA GPUs (sm90+)
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../runtime/runtime.h"

namespace tvm {
namespace tl {

using namespace tir;

#if (CUDA_MAJOR_VERSION >= 12)
class LowerHopperIntrin : public StmtExprMutator {
23
public:
24
  static PrimFunc Substitute(PrimFunc &f, bool disable_shuffle_elect) {
25
    PrimFuncNode *fptr = f.CopyOnWrite();
26
    LowerHopperIntrin substituter(disable_shuffle_elect);
27
    fptr->body = substituter.VisitStmt(f->body);
28
    Map<Var, Array<PrimExpr>> init_desc_arg_map;
29
30
31
    // Collect prologue/epilogue statements for host-side setup/teardown
    Array<Stmt> prologue_stmts;
    Array<Stmt> epilogue_stmts;
32
    for (const auto &[call, var] : substituter.desc_map_) {
33
      // Should allocate 128 bytes for TensorMap on stack
34
      Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
35
                             {StringImm("tvm_ffi_any"), 16});
36
      Array<PrimExpr> init_desc_args;
37
      if (call->op.same_as(create_tma_descriptor())) {
38
        init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
39
      } else if (call->op.same_as(create_tma_im2col_descriptor())) {
40
41
42
43
44
        init_desc_args.push_back(StringImm(tvm_tensormap_create_im2col));
      } else {
        CHECK(0) << call->op;
      }
      init_desc_args.push_back(var);
45
46
      init_desc_args.insert(init_desc_args.end(), call->args.begin(),
                            call->args.end());
47
      // add to function attribute
48
49
      Call init_desc =
          Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
50
51
      // Accumulate TMA descriptor init into prologue
      prologue_stmts.push_back(LetStmt(var, alloc_desc, Evaluate(init_desc)));
52
      init_desc_arg_map.Set(var, init_desc_args);
53
    }
54
    f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

    // Additionally, if L2 persistent cache annotations were lowered earlier,
    // materialize TVM FFI calls to set the stream access policy window.
    if (f->attrs.defined() && f->attrs->dict.count("l2_persistent_map")) {
      auto l2_map =
          f->GetAttr<Map<String, Array<PrimExpr>>>("l2_persistent_map");
      if (l2_map.defined()) {
        // Build a lookup from buffer name to Buffer object
        std::unordered_map<std::string, Buffer> name2buf;
        for (const auto &kv : f->buffer_map) {
          name2buf.emplace(kv.second->name, kv.second);
        }
        for (const auto &kv : l2_map.value()) {
          const std::string buf_name = kv.first;
          const Array<PrimExpr> &args = kv.second;
          if (name2buf.count(buf_name) == 0) {
            continue;
          }
          const Buffer &buf = name2buf.at(buf_name);
          // Build base pointer expression (read access)
          PrimExpr base_ptr = buf.access_ptr(1);
          // Args packed: func_name, base_ptr, num_bytes, hit_ratio
          Array<PrimExpr> packed_args;
          packed_args.push_back(
              StringImm(tvm_cuda_stream_set_access_policy_window));
          packed_args.push_back(base_ptr);
          // size_in_bytes (args[1]) then hit_ratio (args[0])
          ICHECK_GE(args.size(), 2);
          packed_args.push_back(args[1]);
          packed_args.push_back(args[0]);
          prologue_stmts.push_back(Evaluate(Call(
              DataType::Int(32), builtin::tvm_call_packed(), packed_args)));
        }
        // Add a single epilogue call to reset the access policy window and
        // restore L2 limit
        Array<PrimExpr> reset_args;
        reset_args.push_back(
            StringImm(tvm_cuda_stream_reset_access_policy_window));
        epilogue_stmts.push_back(Evaluate(
            Call(DataType::Int(32), builtin::tvm_call_packed(), reset_args)));
      }
    }

    // Stitch prologue statements before the original body
    if (!prologue_stmts.empty()) {
      // Chain the Let/Evaluate statements sequentially
      Stmt seq = prologue_stmts.size() == 1 ? prologue_stmts[0]
                                            : SeqStmt(prologue_stmts);
      fptr->body = SeqStmt({seq, fptr->body});
    }
    if (!epilogue_stmts.empty()) {
      Stmt seq_end = epilogue_stmts.size() == 1 ? epilogue_stmts[0]
                                                : SeqStmt(epilogue_stmts);
      fptr->body = SeqStmt({fptr->body, seq_end});
    }
110
111
112
    return f;
  }

113
114
115
  Stmt VisitStmt_(const AttrStmtNode *op) final {
    // Insert the prefetch TMA descriptor statement TO the beginning of the
    // kernel
116
117
118
119
120
121
122
123
124
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        auto body = StmtExprMutator::VisitStmt(op->body);
        if (prefetch_calls_.empty() && init_mbarrier_calls_.empty()) {
          return AttrStmt(op->node, op->attr_key, op->value, body);
        } else {
          Array<Stmt> stmt_seq;
          if (!init_mbarrier_calls_.empty()) {
125
126
127
            auto alloc_mbarrier =
                Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
                              {static_cast<int>(init_mbarrier_calls_.size())}));
128
129
130
131
            stmt_seq.push_back(alloc_mbarrier);
          }

          auto stmts = prefetch_calls_;
132
133
          stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
                       init_mbarrier_calls_.end());
134
135
136
137
138
139
140
141
142
          PrimExpr condition;
          if (!disable_shuffle_elect_) {
            condition = Call(DataType::Bool(), tl_shuffle_elect(), {0});
          } else {
            condition = EQ(iv->var, 0);
          }
          auto stmt_ = IfThenElse(condition,
                                  stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
          stmt_seq.push_back(stmt_);
143
          if (!init_mbarrier_calls_.empty()) {
144
145
146
147
148
149
150
151
152
153
            // Note from FlashAttention:
            // Helps with visibility of barrier init operations across warps /
            // cta / cluster Available as a separate function so as to batch
            // inits across barriers and fence once Note : It must be composed
            // with an appropriate sync instruction with the right scope to
            // ensure visibility eg. __syncthreads() or a cluster_arrive() +
            // cluster_wait()
            Stmt mem_fence = Evaluate(Call(
                DataType::Handle(), tvm::tl::ptx_fence_barrier_init(), {}));
            stmt_seq.push_back(mem_fence);
154
155
156
            Stmt mem_sync =
                Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
                              {StringImm("shared")}));
157
158
159
160
161
162
163
164
165
166
167
168
169
            stmt_seq.push_back(mem_sync);
          }
          stmt_seq.push_back(body);

          prefetch_calls_.clear();
          init_mbarrier_calls_.clear();
          return AttrStmt(op->node, op->attr_key, op->value, SeqStmt(stmt_seq));
        }
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

170
  PrimExpr VisitExpr_(const CallNode *call) final {
171
172
    if (call->op.same_as(create_tma_descriptor()) ||
        call->op.same_as(create_tma_im2col_descriptor())) {
173
      Var var;
174
      auto iter = desc_map_.find(tvm::ffi::GetRef<Call>(call));
175
176
177
178
      if (iter != desc_map_.end()) {
        var = iter->second;
      } else {
        String name = call->args[2].as<Var>().value()->name_hint;
179
180
        var = Var(name + "_desc",
                  PointerType(PrimType(cuTensorMapType()), "grid_constant"));
181
        desc_map_[tvm::ffi::GetRef<Call>(call)] = var;
182
183
184
        prefetch_calls_.push_back(
            Evaluate(Call(DataType::Handle(), builtin::call_extern(),
                          {StringImm("tl::prefetch_tma_descriptor"), var})));
185
186
      }
      return var;
187
    } else if (call->op.same_as(create_list_of_mbarrier())) {
188
      ICHECK(init_mbarrier_calls_.empty());
189
190
      int num_barriers = static_cast<int>(call->args.size());
      for (int i = 0; i < num_barriers; i++) {
191
        PrimExpr mbarrier = Call(DataType::Handle(), get_mbarrier(), {i});
192
193
194
        init_mbarrier_calls_.push_back(Evaluate(
            Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
                 {mbarrier, call->args[i]})));
195
196
197
198
199
200
201
      }
      return 0;
    } else {
      return StmtExprMutator::VisitExpr_(call);
    }
  }

202
private:
203
204
205
  Array<Stmt> prefetch_calls_;
  Array<Stmt> init_mbarrier_calls_;
  std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
206
207
208
  LowerHopperIntrin(bool disable_shuffle_elect)
      : disable_shuffle_elect_(disable_shuffle_elect) {}
  bool disable_shuffle_elect_;
209
210
211
212
213
};

using namespace tir::transform;

/*!
 * \brief Create the PrimFunc-level pass that lowers Hopper intrinsics.
 *
 * The pass reads the `kDisableShuffleElect` option from the pass context
 * (default: false) and forwards it to `LowerHopperIntrin::Substitute`.
 *
 * \return The pass, registered under the name "tl.LowerHopperIntrin".
 */
tvm::transform::Pass LowerHopperIntrin() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
    bool disable_shuffle_elect =
        ctx->GetConfig<Bool>(kDisableShuffleElect, Bool(false)).value();
    return LowerHopperIntrin::Substitute(f, disable_shuffle_elect);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}

222
TVM_FFI_STATIC_INIT_BLOCK() {
223
224
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
225
}
#endif // (CUDA_MAJOR_VERSION >= 12)

} // namespace tl
} // namespace tvm