"examples/linear_attention/example_mamba_chunk_state.py" did not exist on "64f17c2f369e612cc297d358f607307a615bbb59"
bulk_copy.cc 14.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 * \file tl/op/bulk_copy.cc
 * \brief Bulk copy operator.
 *
 */

#include "bulk_copy.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/cuda.h"
#include "../target/utils.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

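// Map a TVM DataType onto the CUtensorMapDataType enum used when encoding a
// TMA descriptor. Types without a dedicated enum entry (e.g. the fp8 variants)
// are transferred as raw bytes via the UINT8 encoding.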
static int to_CUtensorMapDataType(DataType dtype) {
  CUtensorMapDataType tp;
  if (dtype.is_float()) {
    switch (dtype.bits()) {
    case 64:
      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
      break;
    case 32:
      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
      break;
    case 16:
      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
      break;
    case 8:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
      break;
    default:
      ICHECK(0) << dtype;
    }
  } else if (dtype.is_bfloat16()) {
    tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
  } else if (dtype.is_e4m3_float8() || dtype.is_e5m2_float8()) {
    tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
  } else if (dtype.is_int()) {
    switch (dtype.bits()) {
    case 64:
      tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
      break;
    case 32:
      tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
      break;
    case 16:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
      break;
    case 8:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
      break;
    default:
      ICHECK(0) << dtype;
    }
  } else if (dtype.is_uint()) {
    switch (dtype.bits()) {
    case 64:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
      break;
    case 32:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
      break;
    case 16:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
      break;
    case 8:
      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
      break;
    default:
      ICHECK(0) << dtype;
    }
  } else {
    ICHECK(0) << dtype;
  }
  return static_cast<int>(tp);
}

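// Reverse an array. Buffer shapes in TVM list the outermost dimension first,
// while the TMA descriptor expects the innermost (fastest-varying) dimension
// first, so shapes, strides and coordinates are reversed before encoding.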
template <typename T> static Array<T> ReverseArray(Array<T> array) {
  return Array<T>{array.rbegin(), array.rend()};
}

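// Attempt to lower this copy to a Hopper TMA bulk copy (global <-> shared).
// Returns an empty Stmt when TMA lowering is disabled, the target is not
// Hopper, or the source/destination scopes are not a global<->shared pair;
// the caller is then expected to fall back to a non-TMA lowering.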
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (T.disable_tma_lower)
    return Stmt();
  if (!TargetIsHopper(T.target))
    return Stmt();
  bool is_load;
  if (src.scope() == "global" &&
      (dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
    // Use the Hopper TMA bulk copy instructions
    is_load = true;
  } else if (dst.scope() == "global" &&
             (src.scope() == "shared.dyn" || src.scope() == "shared")) {
    is_load = false;
  } else {
    return Stmt();
  }
  Buffer global_tensor = is_load ? src : dst;
  Buffer shared_tensor = is_load ? dst : src;
  Layout shared_layout;
  if (T.layout_map.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
  }
  if (T.layout_map.count(global_tensor)) {
    ICHECK(T.layout_map.count(global_tensor) == 0)
        << "Cannot support global layout.";
  }

  TMADesc desc;

  // Verify copy rank
  desc.rank = global_tensor->shape.size();
  ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank;

  // Verify datatype
  ICHECK(global_tensor->dtype == shared_tensor->dtype)
      << "Copy between buffer " << global_tensor->name << " and "
      << shared_tensor->name << " with different data type "
      << global_tensor->dtype << " and " << shared_tensor->dtype;

  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);

  // Global Tensor Shape and Stride
  auto global_range = is_load ? src_range : dst_range;
  desc.global_addr = global_tensor->data;
  desc.global_shape = ReverseArray(global_tensor->shape);
  Array<PrimExpr> global_coords =
      ReverseArray(global_range.Map([](Range r) { return r->min; }));
  if (!global_tensor->strides.empty()) {
    desc.global_stride = ReverseArray(global_tensor->strides);
  } else {
    // Create stride from shape
    PrimExpr stride = 1;
    desc.global_stride.reserve(desc.rank);
    for (size_t i = 0; i < desc.rank; i++) {
      desc.global_stride.push_back(stride);
      stride *= desc.global_shape[i];
    }
  }
  // The first stride element should be 1
  ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
  // Make global stride in bytes
  desc.global_stride = desc.global_stride.Map(
      [&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });

  // Smem Box
  desc.smem_box =
      ReverseArray(global_range.Map([](Range r) { return r->extent; }));
  desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1));

  // L2 & OOB
  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

  // Detect smem layout
  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
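  // The swizzle mode is recovered by structurally comparing the remapped
  // shared-memory layout against the known padded / half-bank / full-bank
  // GEMM layouts for this tile shape and element width.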
  if (!shared_layout.defined()) {
    desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
  } else {
    ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
    auto stride = as_const_int(shared_layout->InputShape()[0]);
    auto continuous = as_const_int(shared_layout->InputShape()[1]);
    ICHECK(stride != nullptr && continuous != nullptr);
    if (StructuralEqual()(shared_layout, makeGemmABLayoutPadded(
                                             *stride, *continuous,
                                             shared_tensor->dtype.bits()))) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
    } else if (StructuralEqual()(
                   shared_layout,
                   makeHalfBankSwizzleLayout(*stride, *continuous,
                                             shared_tensor->dtype.bits()))) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
    } else if (StructuralEqual()(
                   shared_layout,
                   makeFullBankSwizzleLayout(*stride, *continuous,
                                             shared_tensor->dtype.bits()))) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
    } else {
      ICHECK(0) << "Cannot detect TMA layout.";
    }
  }

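  // The innermost box extent is constrained by the swizzle mode (64B or 128B
  // worth of elements) and by the 256-element TMA limit; wider tiles are
  // issued as several TMA instructions below.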
  auto inner_box_dim = as_const_int(desc.smem_box[0]);
  ICHECK(inner_box_dim != nullptr);
  int instruction_dim = *inner_box_dim;
  if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
    instruction_dim = 64 / src->dtype.bytes();
  } else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B)) {
    instruction_dim = 128 / src->dtype.bytes();
  }
  if (instruction_dim > 256) {
    // A TMA box dimension must not exceed 256 elements.
    // If the tile is wider (e.g. 512), split the copy into 256-wide chunks.
    ICHECK((*inner_box_dim) % 256 == 0)
        << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256";
    instruction_dim = 256;
  }
  ICHECK((*inner_box_dim) % instruction_dim == 0);
  desc.smem_box.Set(0, PrimExpr(instruction_dim));

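  // Pack the descriptor fields into a create_tma_descriptor intrinsic call;
  // the resulting handle becomes the first argument of the TMA load/store
  // issued below.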
  Call create_descriptor =
      Call(DataType::Handle(), create_tma_descriptor(), desc.EncodeCallArgs());

  Array<PrimExpr> args;
  args.reserve(desc.rank + 3);
  args.push_back(create_descriptor);
  if (is_load)
    args.push_back(0); // mbarrier id placeholder
  auto op = is_load ? tma_load() : tma_store();

  Stmt tma_copy;

  if ((*inner_box_dim) != instruction_dim) {
    Var loop_var("i");
    int loop_extent = (*inner_box_dim) / instruction_dim;
    PrimExpr total_elements = 1;
    for (auto e : desc.smem_box)
      total_elements *= e;
    PrimExpr shared_addr =
        shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
                                 total_elements * loop_var, total_elements);
    args.push_back(shared_addr);
    global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
    for (auto coord : global_coords)
      args.push_back(coord);
    tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
                   Evaluate(Call(DataType::Handle(), op, args)));
  } else {
    PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
    args.push_back(shared_addr);
    for (auto coord : global_coords)
      args.push_back(coord);
    tma_copy = Evaluate(Call(DataType::Handle(), op, args));
  }
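  // TMA transfers are issued by a single thread; guard the copy with thread 0.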
  tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);

  return tma_copy;
}

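// Flatten the TMA descriptor into the flat argument list consumed by the
// create_tma_descriptor intrinsic: dtype, rank and base address, followed by
// the per-dimension shape, stride, box and element-stride arrays, then the
// interleave, swizzle, L2-promotion and OOB-fill settings.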
Array<PrimExpr> TMADesc::EncodeCallArgs() const {
  Array<PrimExpr> args;
  args.reserve(rank * 4 + 7);

  args.push_back(data_type);
  args.push_back(static_cast<int>(rank));
  args.push_back(global_addr);
  for (auto e : global_shape)
    args.push_back(e);
  for (auto e : global_stride)
    args.push_back(e);
  for (auto e : smem_box)
    args.push_back(e);
  for (auto e : smem_stride)
    args.push_back(e);
  args.push_back(interleave);
  args.push_back(swizzle);
  args.push_back(l2_promotion);
  args.push_back(oob_fill);

  return args;
}

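// A CUtensorMap is an opaque 128-byte object; it is modeled here as a
// 128-lane uint8 vector type.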
DataType cuTensorMapType() { return DataType::UInt(8, 128); }

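// Unpack the c2d_im2col call arguments:
//   (src, dst, nhw_step, c_step, kernel, stride, dilation, padding).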
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  nhw_step = args[2];
  c_step = args[3];
  kernel = args[4].as<IntImm>().value()->value;
  stride = args[5].as<IntImm>().value()->value;
  dilation = args[6].as<IntImm>().value()->value;
  padding = args[7].as<IntImm>().value()->value;
}

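// Lower the im2col gather to a TMA im2col copy (Hopper only). The source is a
// 4-D NHWC tensor and the destination a 2-D [pixel, channel] shared tile;
// global coordinates and the kernel-window offsets are derived from nhw_step
// and c_step below.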
Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
                           arith::Analyzer *analyzer) const {
  ICHECK(TargetIsHopper(T.target));
  ICHECK(src.scope() == "global" &&
         (dst.scope() == "shared.dyn" || dst.scope() == "shared"));
  ICHECK(src->shape.size() == 4);
  ICHECK(dst->shape.size() == 2);
  ICHECK(src->dtype == dst->dtype);
  Layout shared_layout;
  if (T.layout_map.count(dst)) {
    shared_layout = T.layout_map[dst];
  }

  TMAIm2ColDesc desc;
  desc.rank = src->shape.size();
  desc.data_type = to_CUtensorMapDataType(src->dtype);
  desc.global_addr = src->data;
  desc.global_shape = ReverseArray(src->shape);
  if (!src->strides.empty()) {
    desc.global_stride = ReverseArray(src->strides);
  } else {
    // Create stride from shape
    PrimExpr stride = 1;
    desc.global_stride.reserve(desc.rank);
    for (size_t i = 0; i < desc.rank; i++) {
      desc.global_stride.push_back(stride);
      stride *= desc.global_shape[i];
    }
  }
  // The first stride element should be 1
  ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
  // Make global stride in bytes
  desc.global_stride = desc.global_stride.Map(
      [&](PrimExpr e) { return e * src->dtype.bytes(); });
  desc.elem_stride = {1, stride, stride, 1};
  desc.lower_corner = {-padding, -padding};
  desc.upper_corner = {-padding, -padding};
  desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value;
  desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value;
  desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
  if (!shared_layout.defined()) {
    desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
  } else {
    ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
    auto stride = as_const_int(shared_layout->InputShape()[0]);
    auto continuous = as_const_int(shared_layout->InputShape()[1]);
    ICHECK(stride != nullptr && continuous != nullptr);
    if (StructuralEqual()(shared_layout,
                          makeHalfBankSwizzleLayout(*stride, *continuous,
                                                    dst->dtype.bits()))) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
    } else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(
                                                    *stride, *continuous,
                                                    dst->dtype.bits()))) {
      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
    } else {
      ICHECK(0) << "Cannot detect TMA layout.";
    }
  }

  Call create_desc = Call(DataType::Handle(), create_tma_im2col_descriptor(),
                          desc.EncodeCallArgs());

  Array<PrimExpr> global_coords; // c, w, h, n
  Array<PrimExpr> image_offset;  // w, h
  global_coords.reserve(desc.rank);

  ICHECK(analyzer->CanProveEqual(
      FloorMod(desc.global_shape[0], desc.smem_box_channel), 0))
      << "Currently can only support divisible channel case";

  global_coords.push_back(
      FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
  image_offset.push_back(
      dilation *
      FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]),
               kernel));
  image_offset.push_back(dilation * FloorDiv(c_step * desc.smem_box_channel,
                                             desc.global_shape[0] * kernel));

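  // Convolution output extents: (H + 2*pad - (kernel-1)*dilation - 1)/stride + 1,
  // and likewise for W; used to decompose the flat pixel index into (n, h, w).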
  PrimExpr h_dim =
      FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1,
               stride) +
      1;
  PrimExpr w_dim =
      FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1,
               stride) +
      1;
  global_coords.push_back(
      stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
  global_coords.push_back(
      stride *
          FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) -
      padding);
  global_coords.push_back(
      FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));

  Array<PrimExpr> args;
  args.reserve(desc.rank * 2 + 1);
  args.push_back(create_desc);
  args.push_back(0); // mbar placeholder
  auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
  auto shared_addr = dst_buffer.access_ptr(2);
  args.push_back(shared_addr);
  for (auto coord : global_coords)
    args.push_back(coord);
  for (auto offset : image_offset)
    args.push_back(offset);

  Stmt tma_copy =
      IfThenElse(EQ(T.thread_var, 0),
                 Evaluate(Call(DataType::Handle(), tma_load_im2col(), args)));
  return tma_copy;
}

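// Flatten the im2col descriptor for the create_tma_im2col_descriptor
// intrinsic: dtype, rank and base address, the per-dimension shape, stride and
// element-stride arrays, the lower/upper corner offsets, then the box sizes
// and the interleave/swizzle/L2/OOB settings.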
Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
  Array<PrimExpr> args;
  args.reserve(rank * 5 + 5);

  args.push_back(data_type);
  args.push_back(static_cast<int>(rank));
  args.push_back(global_addr);
  for (auto e : global_shape)
    args.push_back(e);
  for (auto e : global_stride)
    args.push_back(e);
  for (auto e : elem_stride)
    args.push_back(e);
  for (auto e : lower_corner)
    args.push_back(e);
  for (auto e : upper_corner)
    args.push_back(e);
  args.push_back(smem_box_pixel);
  args.push_back(smem_box_channel);
  args.push_back(interleave);
  args.push_back(swizzle);
  args.push_back(l2_promotion);
  args.push_back(oob_fill);

  return args;
}

TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
    .set_num_inputs(8)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm