// copy.h
#pragma once

#include "common.h"

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "copy_sm90.h"
#endif

namespace tl {

/// Issues the PTX `cp.async.commit_group` instruction.
///
/// Takes no operands and produces no value; the statement is marked
/// `volatile` so the compiler keeps it even though it has no outputs.
/// See the PTX ISA for the exact grouping semantics of the instruction.
TL_DEVICE void cp_async_commit()
{
  asm volatile("cp.async.commit_group;\n" : :);
}
/// Emits the PTX instruction that waits on outstanding `cp.async` work.
///
/// \tparam N  Compile-time immediate forwarded to `cp.async.wait_group`.
///            `N == 0` is special-cased and emits `cp.async.wait_all`
///            instead of `cp.async.wait_group 0`.
template <int N> TL_DEVICE void cp_async_wait() {
  // A negative group count is meaningless for wait_group; reject it in the
  // front end rather than letting it reach the PTX assembler.
  static_assert(N >= 0, "cp_async_wait: N must be non-negative");
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::);
  } else {
    // "n" constraint: N is embedded in the instruction as an immediate.
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
  }
}

/// Issues one `cp.async` copy of N bytes from global memory to shared memory.
///
/// \tparam N          Copy size in bytes; must be 16, 8 or 4 (enforced by
///                    the static_assert). N == 16 selects the `.cg` form of
///                    the instruction, 8 and 4 select the `.ca` form.
/// \param smem_addr   Destination. It is passed through `smem_ptr_to_uint`
///                    (declared outside this file view — presumably yields
///                    the 32-bit shared-space address; confirm in common.h).
/// \param global_ptr  Source address, passed to the instruction as a 64-bit
///                    operand.
///
/// When `TL_ENABLE_L2_PREFETCH` evaluates to non-zero the instruction carries
/// the additional `.L2::128B` qualifier.
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4);
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    // 16-byte copies use the .cg variant.
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  } else {
    // 8- and 4-byte copies use the .ca variant.
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  }
}

/// Predicated variant of the N-byte global-to-shared `cp.async` copy.
///
/// The instruction is always issued; `cond` only selects the value of the
/// extra source-size operand: N when `cond` is true, 0 when it is false.
/// Per the PTX ISA, a source size smaller than the copy size makes the
/// remaining destination bytes zero-filled, so a false `cond` is expected to
/// zero the N destination bytes rather than skip the copy.
///
/// \tparam N          Copy size in bytes; must be 16, 8 or 4 (enforced by
///                    the static_assert). N == 16 selects the `.cg` form of
///                    the instruction, 8 and 4 select the `.ca` form.
/// \param smem_addr   Destination. It is passed through `smem_ptr_to_uint`
///                    (declared outside this file view — presumably yields
///                    the 32-bit shared-space address; confirm in common.h).
/// \param global_ptr  Source address, passed to the instruction as a 64-bit
///                    operand.
/// \param cond        Selects the source-size operand as described above.
///
/// When `TL_ENABLE_L2_PREFETCH` evaluates to non-zero the instruction carries
/// the additional `.L2::128B` qualifier.
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4);
  // Runtime source size: full width when enabled, zero otherwise.
  int bytes = cond ? N : 0;
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    // 16-byte copies use the .cg variant.
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  } else {
    // 8- and 4-byte copies use the .ca variant.
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  }
}

} // namespace tl