copy.h 2.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include "common.h"

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "copy_sm90.h"
#endif

namespace tl {

// Closes the current batch of cp.async operations issued by the calling thread
// into a new cp.async group (PTX `cp.async.commit_group`). Groups are later
// waited on with cp_async_wait<N>(). cp.async requires sm_80 or newer.
TL_DEVICE void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Blocks the calling thread until at most N of its most recently committed
// cp.async groups are still pending.
//   * N == 0 uses `cp.async.wait_all`, which per PTX is equivalent to
//     `cp.async.commit_group; cp.async.wait_group 0;`. It therefore also waits
//     for operations that were issued but not yet committed.
//   * N > 0 uses `cp.async.wait_group N`.
//
// The "memory" clobber turns the wait into a compiler-level memory barrier.
// Without it, the compiler is free to move ordinary shared-memory loads of the
// cp.async destination above the asm statement and read data before the
// asynchronous copy has landed. The wait only makes the copies visible to the
// calling thread; other threads reading the same shared memory still need a
// block-level barrier (e.g. __syncthreads()).
template <int N> TL_DEVICE void cp_async_wait() {
  static_assert(N >= 0, "cp_async_wait: pending-group count must be >= 0");
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::: "memory");
  } else {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N) : "memory");
  }
}

// Issues an asynchronous copy of N bytes from global memory (`global_ptr`) to
// shared memory (`smem_addr`) using PTX `cp.async`.
//
// Preconditions:
//   * N is 4, 8 or 16 (enforced at compile time).
//   * Per the PTX cp.async rules, both addresses must be aligned to N bytes.
//   * `smem_addr` must point into shared memory. It is converted to a 32-bit
//     shared-window address by smem_ptr_to_uint (declared in common.h).
//
// The copy is only guaranteed to be complete after cp_async_commit() and a
// matching cp_async_wait<...>().
//
// 16-byte copies use the `.cg` variant (caches in L2 only, bypassing L1).
// `.cg` supports only 16-byte transfers, so 4- and 8-byte copies fall back to
// `.ca`. When TL_ENABLE_L2_PREFETCH evaluates to non-zero, an `L2::128B`
// prefetch-size hint is attached.
template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr,
                           void const *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp.async only supports 4, 8 or 16 byte copies");
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"(global_ptr), "n"(N));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"(global_ptr), "n"(N));
  }
}

// Predicated variant of cp_async_gs. It uses the `src-size` operand of
// cp.async:
//   * cond == true: src-size is N, and N bytes are copied from `global_ptr`.
//   * cond == false: src-size is 0. Per PTX, the bytes between src-size and
//     cp-size are zero-filled, so N zero bytes are written to `smem_addr` and
//     nothing is read from `global_ptr`.
// This lets out-of-bounds tiles be zero-padded without branching around the
// instruction.
//
// Same size, alignment, cache-variant, L2-prefetch and commit/wait rules as
// cp_async_gs.
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void const *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4,
                "cp.async only supports 4, 8 or 16 byte copies");
  int bytes = cond ? N : 0;
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"(global_ptr), "n"(N), "r"(bytes));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
        "cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"(global_ptr), "n"(N), "r"(bytes));
  }
}

} // namespace tl