copy.h 2 KB
Newer Older
1
2
3
4
#pragma once

#include "common.h"

5
6
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 900
7
8
#include "copy_sm90.h"
#endif
9
10
11
12
#if __CUDA_ARCH_LIST__ >= 1000
#include "copy_sm100.h"
#endif
#endif
13
14
15

namespace tl {

// Commit every cp.async operation this thread has issued (and not yet
// committed) into a new cp.async-group. Pair with cp_async_wait<N>() to wait
// for the groups to complete.
// Requires SM80+ (cp.async, PTX ISA).
TL_DEVICE void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most N of this thread's most recently committed cp.async-groups
// are still pending. All older groups are complete when this returns.
//   - N == 0 emits cp.async.wait_all. Per the PTX ISA, this also commits any
//     not-yet-committed cp.async operations and waits for all of them.
//   - N > 0 emits cp.async.wait_group N.
//
// Completion only makes the copied bytes visible to the issuing thread. Other
// threads must still synchronize (e.g. __syncthreads()) before reading the
// destination.
//
// The "memory" clobber is needed because the asm has no memory operands.
// Without it, the compiler is not told that the shared-memory destinations
// written by earlier cp.async ops change at this point, so it could hoist
// loads of (or sink stores to) those buffers across the wait.
template <int N> TL_DEVICE void cp_async_wait() {
  static_assert(N >= 0, "cp_async_wait: N must be non-negative");
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::: "memory");
  } else {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N) : "memory");
  }
}

// Issue one asynchronous copy of N bytes (N in {4, 8, 16}) from global memory
// at global_ptr to shared memory at smem_addr.
//
// Preconditions (PTX ISA): both addresses are aligned to N bytes.
// The copy is only guaranteed complete after cp_async_commit() followed by a
// matching cp_async_wait<...>().
//
// Cache operator:
//   - N == 16 uses .cg (cache in L2, bypass L1). PTX permits .cg only with a
//     16-byte copy size.
//   - Smaller copies use .ca (cache at all levels).
// When TL_ENABLE_L2_PREFETCH is non-zero, the .L2::128B qualifier hints that
// the hardware may prefetch a 128-byte L2 block around the source.
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp_async_gs: N must be 4, 8, or 16 bytes");
  // NOTE(review): smem_ptr_to_uint is defined outside this file (common.h).
  // Presumably it converts a generic pointer to a 32-bit shared-window
  // address, as cp.async's [%0] operand expects -- confirm.
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  } else {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  }
}

// Predicated form of cp_async_gs<N>. It always issues one cp.async of
// cp-size N and passes src-size = (cond ? N : 0).
//   - cond == true: copies N bytes from global_ptr, exactly like
//     cp_async_gs<N>.
//   - cond == false: src-size is 0, so no bytes are copied from the source.
//     Per the PTX ISA, the remaining (cp-size - src-size) destination bytes
//     are zero-filled, so all N shared-memory bytes become 0.
// Alignment, cache-operator and prefetch behaviour match cp_async_gs<N>.
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp_async_gs_conditional: N must be 4, 8, or 16 bytes");
  // Runtime src-size operand: N bytes to copy, or 0 to zero-fill instead.
  int bytes = cond ? N : 0;
  // NOTE(review): see cp_async_gs -- smem_ptr_to_uint comes from common.h.
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  } else {
    asm volatile(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  }
}

} // namespace tl