// copy_sm100.h
#pragma once
#include "cuda_fp8.h"
#include "tcgen_05.h"
#include "tcgen_05_ld.h"

namespace tl {

__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
  longlong4 ret;
  asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
               : "l"(ptr));
  return ret;
}

__device__ __forceinline__ void st_global_256(longlong4 *ptr,
                                              const longlong4 &val) {
  asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}

__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
  ulonglong4 ret;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
               : "l"(ptr));
  return ret;
}

// `val` must be a const reference (true of every st_global_256 overload):
// a non-const reference cannot bind to a temporary, so a call like
// st_global_256(ptr, ld_global_256(ptr)) would otherwise fail to compile.
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
                                              const ulonglong4 &val) {
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}

__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) {
  ulonglong4 ret;
  asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
               : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
               : "l"(ptr));
  return ret;
}

__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
                                              const fp8_e4_32_t &val8) {
  // Reinterpret the 32-byte fp8 vector as four 64-bit lanes for the store.
  const ulonglong4 &val = *((const ulonglong4 *)&val8);
  asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
               :
               : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
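
// Usage sketch, not part of the original header: a hypothetical grid-stride
// copy that streams `n` ulonglong4 elements (32 bytes each) through the
// 256-bit load/store paths above. The store binds the load's temporary,
// which is why the const-reference overloads matter.
__device__ __forceinline__ void copy256_example(const ulonglong4 *src,
                                                ulonglong4 *dst,
                                                unsigned long long n) {
  unsigned long long i =
      blockIdx.x * (unsigned long long)blockDim.x + threadIdx.x;
  unsigned long long stride = (unsigned long long)gridDim.x * blockDim.x;
  for (; i < n; i += stride)
    st_global_256(dst + i, ld_global_256(src + i));
}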

// Pack four 16-bit values into one 64-bit word: `x` lands in the lowest
// 16 bits and `w` in the highest.
__device__ __forceinline__ unsigned long long
pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
                const bfloat16_t w) {
  unsigned long long v0 = *((const unsigned short *)&x);
  unsigned long long v1 = *((const unsigned short *)&y);
  unsigned long long v2 = *((const unsigned short *)&z);
  unsigned long long v3 = *((const unsigned short *)&w);
  return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48));
}

__device__ __forceinline__ unsigned long long
pack_float16x4(const half x, const half y, const half z, const half w) {
  unsigned long long v0 = *((const unsigned short *)&x);
  unsigned long long v1 = *((const unsigned short *)&y);
  unsigned long long v2 = *((const unsigned short *)&z);
  unsigned long long v3 = *((const unsigned short *)&w);
  return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48));
}
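
// For example, pack_float16x4(a, b, c, d) produces a 64-bit word with `a`
// in bits [15:0] and `d` in bits [63:48], i.e. the in-memory layout of a
// little-endian half[4] array, so four 16-bit values can be moved with a
// single 8-byte store.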

// Helper to compute the largest K such that 2**K <= N, i.e. floor(log2(N)).
// Requires N > 0 (enforced by the static_assert below).
template <int N, int K = 0>
__device__ __forceinline__ constexpr int get_floor_log2() {
  static_assert(N > 0);
  if constexpr ((1 << (K + 1)) > N)
    return K;
  else
    return get_floor_log2<N, K + 1>();
}
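
// For example, get_floor_log2<1>() == 0, get_floor_log2<7>() == 2,
// get_floor_log2<8>() == 3, and get_floor_log2<9>() == 3.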

// Copy N tmem columns by greedily splitting N into power-of-two segments of
// at most 1 << MAX_LOGN columns each, issuing one tcgen05.ld per segment.
template <typename target_call_cls, int MAX_LOGN, int N, typename dst_t>
__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
                                                dst_t *dst_ptr) {
  static_assert(N > 0);
  constexpr int LOG_N = get_floor_log2<N>();
  constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N);
  // `template` disambiguates the dependent member template `copy`.
  target_call_cls::template copy<CUR_SEGMENT_LEN>(tmem_start_col,
                                                  (uint32_t *)dst_ptr);
  if constexpr (N - CUR_SEGMENT_LEN > 0) {
    tcgen05_ld_core<target_call_cls, MAX_LOGN, N - CUR_SEGMENT_LEN>(
        tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN);
  }
}
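
// For example, N = 13 with MAX_LOGN = 7 issues three copies of 8, 4, and 1
// columns: floor_log2(13) = 3 gives a first segment of 1 << 3 = 8, and the
// tail recursion splits the remaining 5 columns into 4 + 1.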

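// Convenience wrappers over tcgen05_ld_core, one per tcgen05.ld shape. Each
// loads N columns starting at tmem_start_col + tmem_col_offset and then
// fences so the loaded values are visible to subsequent reads. MAX_LOGN
// caps the widest single instruction per shape: x128 for the 32b and 64b
// shapes, x64 for 128b, and x32 for 256b.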
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
                                               dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
                     uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
                                               dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}

template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
                      uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
  tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
      tmem_start_col + tmem_col_offset, dst_ptr);
  tl::fence_view_async_tmem_load();
}
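
// Usage sketch, not part of the original header: load 32 tmem columns into
// per-thread registers. The 32dp32bNx shape is warp-wide, so with N = 32
// each lane of the issuing warp receives 32 uint32_t values; the exact
// lane-to-column mapping is defined by tl::tmem_ld_32dp32bNx. The helper
// name below is hypothetical.
__device__ __forceinline__ void tcgen05_ld_example(uint32_t tmem_start_col,
                                                   uint32_t *regs /*[32]*/) {
  tcgen05_ld_32dp32bNx<32>(tmem_start_col, /*tmem_col_offset=*/0u, regs);
}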

} // namespace tl