copy_sm100.h 5.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#pragma once
#include "cuda_fp8.h"
#include "tcgen_05.h"
#include "tcgen_05_ld.h"

namespace tl {

// Reads 256 bits from global memory at `ptr` with a single vectorized
// ld.global.v4.s64 and returns them as four signed 64-bit lanes.
__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
  long long lane0, lane1, lane2, lane3;
  asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
               : "=l"(lane0), "=l"(lane1), "=l"(lane2), "=l"(lane3)
               : "l"(ptr));
  longlong4 result = {lane0, lane1, lane2, lane3};
  return result;
}

// Writes one 256-bit vector (four signed 64-bit lanes) to global memory at
// `ptr` with a single st.global.v4.s64.
// `val` is taken by const reference, like the ulonglong4 overload: a non-const
// lvalue reference cannot bind to a temporary, so calls such as
// st_global_256(ptr, ld_global_256(ptr)) or stores of const values would not
// compile otherwise.
// NOTE(review): PTX vector accesses generally require `ptr` to be aligned to
// the full access size (32 bytes here) — confirm callers guarantee this.
__device__ __forceinline__ void st_global_256(longlong4 *ptr,
                                              const longlong4 &val) {
  asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}

// Reads 256 bits from global memory at `ptr` with a single vectorized
// ld.global.v4.u64 and returns them as four unsigned 64-bit lanes.
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
  unsigned long long lane0, lane1, lane2, lane3;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(lane0), "=l"(lane1), "=l"(lane2), "=l"(lane3)
               : "l"(ptr));
  ulonglong4 result = {lane0, lane1, lane2, lane3};
  return result;
}

// Writes one 256-bit vector (four unsigned 64-bit lanes) to global memory at
// `ptr` with a single st.global.v4.u64.
// `val` must stay a const reference: a non-const lvalue reference cannot bind
// to a temporary, so st_global_256(ptr, ld_global_256(ptr)) would then fail to
// compile.
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
                                              const ulonglong4 &val) {
  const unsigned long long lane0 = val.x;
  const unsigned long long lane1 = val.y;
  const unsigned long long lane2 = val.z;
  const unsigned long long lane3 = val.w;
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(lane0), "l"(lane1), "l"(lane2), "l"(lane3));
}

// Reads the 256-bit payload at `ptr` (an fp8_e4_32_t in global memory) with a
// single ld.global.v4.u64 and returns the raw bits as four 64-bit lanes; the
// caller reinterprets them as FP8 data.
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) {
  unsigned long long lane0, lane1, lane2, lane3;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(lane0), "=l"(lane1), "=l"(lane2), "=l"(lane3)
               : "l"(ptr));
  ulonglong4 result = {lane0, lane1, lane2, lane3};
  return result;
}

// Writes an fp8_e4_32_t (presumably 32 FP8 e4m3 values, going by the type
// name) to global memory at `ptr` with a single st.global.v4.u64. The payload
// is viewed as a ulonglong4 so it can be passed as four 64-bit operands.
// `val8` is taken by const reference, consistent with the ulonglong4
// overload, so temporaries and const values can be stored as well.
// NOTE(review): PTX vector stores generally require a 32-byte-aligned
// address — confirm callers guarantee this.
__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
                                              const fp8_e4_32_t &val8) {
  // The reinterpretation below is only meaningful if both types are 256 bits.
  static_assert(sizeof(fp8_e4_32_t) == sizeof(ulonglong4),
                "fp8_e4_32_t must be exactly 256 bits wide");
  const ulonglong4 &val = *reinterpret_cast<const ulonglong4 *>(&val8);
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Reads the 256-bit payload at `ptr` (an fp8_e5_32_t in global memory) with a
// single ld.global.v4.u64 and returns the raw bits as four 64-bit lanes; the
// caller reinterprets them as FP8 data.
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
  unsigned long long lane0, lane1, lane2, lane3;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(lane0), "=l"(lane1), "=l"(lane2), "=l"(lane3)
               : "l"(ptr));
  ulonglong4 result = {lane0, lane1, lane2, lane3};
  return result;
}

// Writes an fp8_e5_32_t (presumably 32 FP8 e5m2 values, going by the type
// name) to global memory at `ptr` with a single st.global.v4.u64. The payload
// is viewed as a ulonglong4 so it can be passed as four 64-bit operands.
// `val8` is taken by const reference, consistent with the ulonglong4
// overload, so temporaries and const values can be stored as well.
// NOTE(review): PTX vector stores generally require a 32-byte-aligned
// address — confirm callers guarantee this.
__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
                                              const fp8_e5_32_t &val8) {
  // The reinterpretation below is only meaningful if both types are 256 bits.
  static_assert(sizeof(fp8_e5_32_t) == sizeof(ulonglong4),
                "fp8_e5_32_t must be exactly 256 bits wide");
  const ulonglong4 &val = *reinterpret_cast<const ulonglong4 *>(&val8);
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

// Packs four bfloat16 values into one 64-bit word by copying their raw 16-bit
// encodings: x occupies bits [0, 16), y bits [16, 32), z bits [32, 48) and
// w bits [48, 64).
__device__ __forceinline__ unsigned long long
pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
                const bfloat16_t w) {
  const unsigned short bits[4] = {
      *reinterpret_cast<const unsigned short *>(&x),
      *reinterpret_cast<const unsigned short *>(&y),
      *reinterpret_cast<const unsigned short *>(&z),
      *reinterpret_cast<const unsigned short *>(&w)};
  unsigned long long packed = 0ull;
  // Insert from the highest lane down so bits[0] ends up in the low 16 bits.
#pragma unroll
  for (int lane = 3; lane >= 0; --lane) {
    packed = (packed << 16) | bits[lane];
  }
  return packed;
}

// Packs four fp16 values into one 64-bit word by copying their raw 16-bit
// encodings: x occupies bits [0, 16), y bits [16, 32), z bits [32, 48) and
// w bits [48, 64).
__device__ __forceinline__ unsigned long long
pack_float16x4(const half x, const half y, const half z, const half w) {
  const unsigned short bits[4] = {
      *reinterpret_cast<const unsigned short *>(&x),
      *reinterpret_cast<const unsigned short *>(&y),
      *reinterpret_cast<const unsigned short *>(&z),
      *reinterpret_cast<const unsigned short *>(&w)};
  unsigned long long packed = 0ull;
  // Insert from the highest lane down so bits[0] ends up in the low 16 bits.
#pragma unroll
  for (int lane = 3; lane >= 0; --lane) {
    packed = (packed << 16) | bits[lane];
  }
  return packed;
}

// Compile-time helper: starting from K, walks upward until 2**(K+1) > N and
// returns that K. With the default K = 0 this is floor(log2(N)), i.e. the
// largest K such that 2**K <= N.
// Requires N > 0.
// The comparison is done in long long so that N >= 2**30 does not evaluate
// the overflowing `1 << 32` on int (which makes constant evaluation fail).
// Marked __host__ __device__ so it can also be used in host-side constant
// expressions; device callers are unaffected.
template <int N, int K = 0>
__host__ __device__ __forceinline__ constexpr int get_floor_log2() {
  static_assert(N > 0, "get_floor_log2 requires N > 0");
  if constexpr ((1LL << (K + 1)) > N)
    return K;
  else
    return get_floor_log2<N, K + 1>();
}

// Loads N TMEM columns starting at `tmem_start_col` into `dst_ptr` by splitting
// N into power-of-two segments, largest first, each at most 2**MAX_LOGN
// columns. Each segment is issued as target_call_cls::copy<LEN>(column,
// (uint32_t *)destination), and the remainder is handled by recursing at
// compile time.
// NOTE(review): after each segment the TMEM column advances by LEN and dst_ptr
// advances by LEN elements of dst_t; this assumes copy<LEN> consumes exactly
// LEN columns and writes exactly LEN dst_t elements per thread — confirm for
// the wider load shapes and for pack16.
template <typename target_call_cls, int MAX_LOGN, int N, typename dst_t>
__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
                                                dst_t *dst_ptr) {
  static_assert(N > 0, "tcgen05_ld_core requires N > 0");
  constexpr int LOG_N = get_floor_log2<N>();
  constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N);
  // `copy` is a member template of a dependent type, so the `template`
  // disambiguator is required; without it `copy < LEN > (...)` parses as
  // comparisons in standard C++.
  target_call_cls::template copy<CUR_SEGMENT_LEN>(tmem_start_col,
                                                  (uint32_t *)dst_ptr);
  if constexpr (N - CUR_SEGMENT_LEN > 0) {
    tcgen05_ld_core<target_call_cls, MAX_LOGN, N - CUR_SEGMENT_LEN>(
        tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN);
  }
}

113
// Loads N TMEM columns, starting at tmem_start_col + tmem_col_offset, into
// dst_ptr using tl::tmem_ld_32dp32bNx<pack16>. tcgen05_ld_core splits the load
// into segments of at most 2**7 = 128 columns. Afterwards it calls
// tl::fence_view_async_tmem_load() (presumably making the async load results
// visible before use — confirm in tcgen_05.h).
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  using load_op = tl::tmem_ld_32dp32bNx<pack16>;
  constexpr int kMaxLogSegment = 7;
  const uint32_t first_col = tmem_start_col + tmem_col_offset;
  tcgen05_ld_core<load_op, kMaxLogSegment, N>(first_col, dst_ptr);
  tl::fence_view_async_tmem_load();
}

122
// Loads N TMEM columns, starting at tmem_start_col + tmem_col_offset, into
// dst_ptr using tl::tmem_ld_32dp64bNx<pack16>. tcgen05_ld_core splits the load
// into segments of at most 2**7 = 128 columns. Afterwards it calls
// tl::fence_view_async_tmem_load() (presumably making the async load results
// visible before use — confirm in tcgen_05.h).
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  using load_op = tl::tmem_ld_32dp64bNx<pack16>;
  constexpr int kMaxLogSegment = 7;
  const uint32_t first_col = tmem_start_col + tmem_col_offset;
  tcgen05_ld_core<load_op, kMaxLogSegment, N>(first_col, dst_ptr);
  tl::fence_view_async_tmem_load();
}

131
// Loads N TMEM columns, starting at tmem_start_col + tmem_col_offset, into
// dst_ptr using tl::tmem_ld_32dp128bNx<pack16>. tcgen05_ld_core splits the load
// into segments of at most 2**6 = 64 columns. Afterwards it calls
// tl::fence_view_async_tmem_load() (presumably making the async load results
// visible before use — confirm in tcgen_05.h).
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  using load_op = tl::tmem_ld_32dp128bNx<pack16>;
  constexpr int kMaxLogSegment = 6;
  const uint32_t first_col = tmem_start_col + tmem_col_offset;
  tcgen05_ld_core<load_op, kMaxLogSegment, N>(first_col, dst_ptr);
  tl::fence_view_async_tmem_load();
}

140
// Loads N TMEM columns, starting at tmem_start_col + tmem_col_offset, into
// dst_ptr using tl::tmem_ld_32dp256bNx<pack16>. tcgen05_ld_core splits the load
// into segments of at most 2**5 = 32 columns. Afterwards it calls
// tl::fence_view_async_tmem_load() (presumably making the async load results
// visible before use — confirm in tcgen_05.h).
template <int N, bool pack16, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  using load_op = tl::tmem_ld_32dp256bNx<pack16>;
  constexpr int kMaxLogSegment = 5;
  const uint32_t first_col = tmem_start_col + tmem_col_offset;
  tcgen05_ld_core<load_op, kMaxLogSegment, N>(first_col, dst_ptr);
  tl::fence_view_async_tmem_load();
}

} // namespace tl