copy.h 2.06 KB
Newer Older
1
2
3
4
#pragma once

#include "common.h"

5
6
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 900
7
8
#include "copy_sm90.h"
#endif
9
10
11
12
#if __CUDA_ARCH_LIST__ >= 1000
#include "copy_sm100.h"
#endif
#endif
13
14
15

namespace tl {

// Commit every cp.async operation this thread has issued since the previous
// commit into a new cp.async group (PTX `cp.async.commit_group`). Use together
// with cp_async_wait<N>() to wait for committed groups to complete.
TL_DEVICE void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait on this thread's outstanding cp.async work.
//   N == 0 : emits `cp.async.wait_all`, i.e. wait for all outstanding
//            cp.async operations of the executing thread.
//   N  > 0 : emits `cp.async.wait_group N`, i.e. wait until at most N of the
//            most recently committed groups are still pending.
// The "memory" clobber turns the wait into a compiler-level memory barrier.
// Without it, the compiler is free to schedule ordinary loads of the
// destination shared memory before the asm, reading stale data.
// NOTE: per the PTX ISA, the wait only covers the executing thread's own
// copies. Reading data copied by other threads still needs a block-level
// barrier at the call site.
template <int N> TL_DEVICE void cp_async_wait() {
  static_assert(N >= 0, "cp_async_wait: N must be non-negative");
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::: "memory");
  } else {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N) : "memory");
  }
}

// Issue one asynchronous global -> shared copy of N bytes (PTX `cp.async`).
// The copy is only guaranteed complete after cp_async_commit() followed by a
// matching cp_async_wait<>().
//   smem_addr  : destination in shared memory. It is converted by
//                smem_ptr_to_uint() (defined elsewhere) into the 32-bit value
//                bound to the `[%0]` shared-space operand.
//   global_ptr : source address in global memory, passed as a 64-bit operand.
// N must be 16, 8 or 4 bytes:
//   - N == 16 uses the `.cg` variant. PTX only allows `.cg` with a 16-byte
//     cp-size.
//   - Smaller sizes use `.ca`.
// When TL_ENABLE_L2_PREFETCH is nonzero, a `.L2::128B` prefetch-size hint is
// appended. Alignment of both addresses to N bytes is a PTX requirement and
// is not checked here.
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr,
                           void const *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp_async_gs: N must be 16, 8 or 4 bytes");
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void const *)(global_ptr)), "n"(N));
  } else {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void const *)(global_ptr)), "n"(N));
  }
}

// Predicated variant of cp_async_gs. It uses the same N / `.cg`-vs-`.ca` /
// prefetch-hint selection, and additionally passes src-size = (cond ? N : 0)
// as the optional fourth cp.async operand.
// Per the PTX ISA, only src-size bytes are read from global_ptr and the rest
// of the N-byte destination is zero-filled. A false `cond` therefore writes N
// zero bytes to shared memory without reading global memory, so callers can
// pad out-of-range elements without branching around the instruction.
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void const *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp_async_gs_conditional: N must be 16, 8 or 4 bytes");
  int bytes = cond ? N : 0; // src-size operand: N copies, 0 zero-fills
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
  } else {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void const *)(global_ptr)), "n"(N), "r"(bytes));
  }
}

} // namespace tl