// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cuda.h>

#include "common.h"

namespace tl {

// Asynchronous TMA bulk tensor loads (global -> shared memory), one overload
// per tensor rank (1-D through 5-D). Completion is signaled on `smem_mbar`
// via mbarrier complete_tx.
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3, %4}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2,
                        int32_t const &crd3) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
               : "memory");
}

TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
                        void const *const smem_ptr, int32_t const &crd0,
                        int32_t const &crd1, int32_t const &crd2,
                        int32_t const &crd3, int32_t const &crd4) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
               "complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
               : "memory");
}

// TMA im2col load of a 4-D (N, H, W, C) tensor; `offset_w` and `offset_h`
// select the pixel offsets within the convolution filter window.
TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
                               uint64_t &smem_mbar, void const *const smem_ptr,
                               int32_t const &coord_c, int32_t const &coord_w,
                               int32_t const &coord_h, int32_t const &coord_n,
                               uint16_t const &offset_w,
                               uint16_t const &offset_h) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
               ":complete_tx::bytes"
               " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
               :
               : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
                 "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
                 "h"(offset_w), "h"(offset_h)
               : "memory");
}

// Asynchronous TMA bulk tensor stores (shared -> global memory), one overload
// per tensor rank. Completion is tracked with tma_store_arrive() and
// tma_store_wait<N>().
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);

  asm volatile(
      "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
      :
      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
      : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);

  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
               "{%2, %3}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);

  asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, "
               "{%2, %3, %4}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);

  asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, "
               "{%2, %3, %4, %5}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2), "r"(crd3)
               : "memory");
}

TL_DEVICE void tma_store(const CUtensorMap &descriptor,
                         void const *const smem_ptr, int32_t const &crd0,
                         int32_t const &crd1, int32_t const &crd2,
                         int32_t const &crd3, int32_t const &crd4) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);

  asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, "
               "{%2, %3, %4, %5, %6}], [%1];"
               :
               : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
                 "r"(crd2), "r"(crd3), "r"(crd4)
               : "memory");
}

// Prefetch the tensor-map descriptor into cache ahead of its first use.
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
  asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}

// Initialize a shared-memory mbarrier expecting `arrive_count` arrivals per
// phase.
TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.init.shared.b64 [%1], %0;"
               :
               : "r"(arrive_count), "r"(smem_int_ptr));
}

// Spin until the mbarrier phase identified by `phase_bit` completes.
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("{\n"
               ".reg .pred                P1;\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
               "@!P1                      bra.uni LAB_WAIT;\n"
               "}\n" ::"r"(smem_int_ptr),
               "r"(phase_bit));
}

// Arrive on the mbarrier without updating its transaction count.
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}

// Raise the mbarrier's expected-transaction count by `transaction_bytes`.
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
                                  uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

// Arrive and raise the expected-transaction count in a single operation.
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
                                         uint32_t transaction_bytes) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
               :
               : "r"(transaction_bytes), "r"(smem_int_ptr));
}

// Make the mbarrier also track completion of prior cp.async operations.
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
               :
               : "r"(smem_int_ptr));
}
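// Usage sketch (illustrative only; `desc`, `smem_buf`, `bytes`, and the
// coordinates are hypothetical names, not part of this header): one thread
// posts the expected byte count and issues a 2-D TMA load, then every thread
// waits on the barrier.
//
//   __shared__ alignas(8) uint64_t bar;
//   if (threadIdx.x == 0)
//     tl::mbarrier_init(bar, 1); // one arriving thread per phase
//   __syncthreads();
//   if (threadIdx.x == 0) {
//     tl::mbarrier_arrive_expect_tx(bar, bytes); // bytes the load deposits
//     tl::tma_load(desc, bar, smem_buf, crd0, crd1);
//   }
//   tl::mbarrier_wait(bar, 0); // phase 0 completes once `bytes` arrive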

// Order prior generic-proxy shared-memory writes before async-proxy (TMA)
// accesses.
TL_DEVICE void fence_proxy_async() {
  asm volatile("fence.proxy.async.shared::cta;" : :);
}

// Commit previously issued TMA stores from this thread into a bulk-async
// group.
TL_DEVICE void tma_store_arrive() {
  asm volatile("cp.async.bulk.commit_group;");
}

// Wait until at most `Count` committed bulk-async groups remain in flight.
template <int Count> TL_DEVICE void tma_store_wait() {
  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory");
}
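// Usage sketch (illustrative only; `desc`, `smem_buf`, and the coordinates
// are hypothetical): shared memory written through the generic proxy must be
// fenced before the TMA store reads it, and the bulk group must be drained
// before the buffer is reused.
//
//   tl::fence_proxy_async(); // order smem writes before the TMA read
//   if (threadIdx.x == 0) {
//     tl::tma_store(desc, smem_buf, crd0, crd1);
//     tl::tma_store_arrive();  // commit the store into a bulk-async group
//     tl::tma_store_wait<0>(); // block until smem has been fully read
//   }
//   __syncthreads();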

// Barrier over a subset of threads: each caller arrives on the mbarrier and
// spins until the current phase completes.
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
  uint64_t state = 0;
  asm volatile("{\n"
               ".reg .pred                P1;\n"
               "mbarrier.arrive.shared.b64 %1, [%0];\n"
               "LAB_WAIT:\n"
               "mbarrier.try_wait.shared.b64 P1, [%0], %1;\n"
               "@!P1                      bra.uni LAB_WAIT;\n"
               "}\n"
               :
               : "r"(smem_int_ptr), "l"(state));
}

// Raise this warpgroup's per-thread register ceiling (setmaxnreg.inc).
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

// Lower this warpgroup's per-thread register ceiling (setmaxnreg.dec),
// returning registers to the SM pool.
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
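// Usage sketch (illustrative only; the warp-role predicate and the register
// counts are hypothetical, and counts must be multiples of 8 in [24, 256]):
// a producer warpgroup that only issues TMA traffic can donate registers to
// the math warpgroups.
//
//   if (is_producer_warpgroup) {
//     tl::warpgroup_reg_dealloc<24>(); // shrink: TMA issue needs few regs
//     // ... issue tma_load / mbarrier traffic ...
//   } else {
//     tl::warpgroup_reg_alloc<240>();  // grow: room for accumulators
//     // ... tensor-core math ...
//   }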

} // namespace tl