copy.h 7.3 KB
Newer Older
Lukinon's avatar
Lukinon committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#pragma once

#include "common.h"

using f32 = float;
// using f16 = _Float16;

using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;

using index_t = u32;

using ck_tile::int32x4_t;

// In-memory image of a four-dword buffer resource: a 64-bit base pointer
// followed by two 32-bit words. The struct is packed so that it can be
// bit-cast to a four-lane 32-bit vector (see make_wave_buffer_resource).
struct __attribute__((packed)) buffer_resource {
  void const *ptr;      // dwords 0-1: base address
  std::uint32_t range;  // dword 2: filled from the `size` argument
  std::uint32_t config; // dword 3: filled with CK_TILE_BUFFER_RESOURCE_3RD_DWORD
};

// Builds a buffer resource for `ptr` and returns it as a four-lane 32-bit
// vector whose every component has been passed through
// __builtin_amdgcn_readfirstlane.
//
// ptr  : base address stored in the first two dwords.
// size : value stored in the third dword (defaults to 0xffffffff).
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
                                                   uint32_t size = 0xffffffff) {
  const buffer_resource descriptor{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
  const int32x4_t per_lane = __builtin_bit_cast(int32x4_t, descriptor);

  // Broadcast each dword from the first lane so the result is wave-uniform.
  int32x4_t uniform;
  uniform.x = __builtin_amdgcn_readfirstlane(per_lane.x);
  uniform.y = __builtin_amdgcn_readfirstlane(per_lane.y);
  uniform.z = __builtin_amdgcn_readfirstlane(per_lane.z);
  uniform.w = __builtin_amdgcn_readfirstlane(per_lane.w);
  return uniform;
}

// Writes `m0_value` into the m0 register with a single s_mov_b32.
// The value is supplied through a scalar-register ("s") operand.
__device__ void init_m0(uint32_t m0_value) {
  asm volatile("s_mov_b32 m0, %[value]"
               : /* no outputs */
               : [value] "s"(m0_value)
               : "memory");
}

// Adds `m0_inc` to the m0 register with a single s_add_u32.
// The operand uses the "n" (immediate) constraint, so `m0_inc` has to
// resolve to a compile-time constant at the point the asm is emitted.
__device__ void inc_m0(uint32_t m0_inc) {
  asm volatile("s_add_u32 m0, %[step], m0"
               : /* no outputs */
               : [step] "n"(m0_inc)
               : "memory");
}

qisan's avatar
qisan committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/* Advances the 64-bit value held in (res).x (low half) and (res).y (high
 * half) by `stride`, in place. (res).z and (res).w are not touched.
 * `stride` may be negative; the addition is done in 64-bit unsigned
 * arithmetic, so it wraps modulo 2^64. */
#define UPDATE_WAVE_BUFFER_RESOURCE(res, stride)                               \
  do {                                                                         \
    /* Zero-extend both halves so the low word is never sign-extended. */     \
    const uint64_t base_lo_ = static_cast<uint32_t>((res).x);                  \
    const uint64_t base_hi_ = static_cast<uint32_t>((res).y);                  \
                                                                               \
    /* Rebuild the 64-bit value and apply the stride. */                       \
    const uint64_t advanced_ = ((base_hi_ << 32) | base_lo_) + (stride);       \
                                                                               \
    /* Split the result back into the two 32-bit components. */                \
    (res).x = static_cast<int32_t>(advanced_);                                 \
    (res).y = static_cast<int32_t>(advanced_ >> 32);                           \
  } while (0)

Lukinon's avatar
Lukinon committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
namespace tl {

// Intentionally a no-op: on AMDGPU the memory fence is committed
// automatically, so there is nothing to do for an explicit commit.
TL_DEVICE void cp_async_commit() {
  return;
}

// Global-memory-only fence: emits `s_waitcnt vmcnt(cnt)`.
// `cnt` is passed with the "n" (immediate) constraint, so it has to resolve
// to a compile-time constant at the point the asm is emitted.
__device__ void async_gld_fence(index_t cnt) {
  asm volatile("s_waitcnt vmcnt(%[pending])"
               : /* no outputs */
               : [pending] "n"(cnt)
               : "memory");
}

// Described upstream as the "global memory and shared memory" fence.
// The instruction emitted is `s_waitcnt lgkmcnt(cnt)`; no vmcnt wait is
// issued here. `cnt` uses the "n" (immediate) constraint, so it has to
// resolve to a compile-time constant at the point the asm is emitted.
__device__ void async_gld_sld_fence(index_t cnt) {
  asm volatile("s_waitcnt lgkmcnt(%[pending])"
               : /* no outputs */
               : [pending] "n"(cnt)
               : "memory");
}

// Emits a single `s_barrier` instruction; the "memory" clobber keeps the
// compiler from moving memory accesses across it.
__device__ void wave_barrier() {
  asm volatile("s_barrier"
               : /* no outputs */
               : /* no inputs */
               : "memory");
}

// Waits for outstanding asynchronous copies by forwarding N to
// async_gld_fence. async_gld_sld_fence(N) is the alternative the original
// author noted but did not enable.
template <int N = 0> TL_DEVICE void cp_async_wait() {
  constexpr int pending = N;
  async_gld_fence(pending);
}

// Moves the destination pointer into m0 and then issues one
// `buffer_load_dword ... offen lds` instruction.
//
// smem    : destination pointer; it is passed through
//           __builtin_amdgcn_readfirstlane and the result is written to m0.
// rsrc    : four-dword buffer resource, supplied in scalar registers ("s").
// voffset : offset supplied in a vector register ("v") for `offen`.
//
// The `pre_nop` template parameter is not referenced in the body.
// NOTE(review): the asm overwrites m0 but the clobber list names only
// "memory" — confirm no caller relies on m0 surviving this call.
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
                                              index_t voffset) {
  // The pointer is converted to an integer before the builtin; presumably it
  // is narrowed to the builtin's 32-bit argument — TODO confirm.
  auto const lds_ptr_sgpr =
      __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
  // Operand map: %0 = value for m0, %1 = voffset, %2 = rsrc.
  asm volatile("s_mov_b32 m0, %0; \n\t"
               "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
               "v"(voffset), "s"(rsrc)
               : "memory");
}

qisan's avatar
qisan committed
89
90
// Issues `load_count` asynchronous dwordx4 buffer loads through
// async_buffer_load_dwordx4_v, stepping the destination and the buffer
// resource base address between loads.
//
// Template parameters:
//   N           : bytes per transfer; only 16 is implemented.
//   smem_offset : forwarded as the template argument of
//                 async_buffer_load_dwordx4_v for the first load.
//   load_count  : number of loads to issue (must be >= 1).
//   i_sstride   : destination step between consecutive loads. For
//                 load_count 1/2/4 it is added to `smem_offset`; otherwise it
//                 is added to `lds_base_ptr` as a byte offset.
//                 NOTE(review): the two forms agree only if the template
//                 offset of async_buffer_load_dwordx4_v is in bytes — confirm.
//   i_gstride   : amount added to the resource base address between loads.
//   k_gstride   : net amount added to the resource base address over the
//                 whole call (the last step is k_gstride minus the strides
//                 already applied).
//
// Parameters:
//   lds_base_ptr : destination base pointer.
//   res          : buffer resource. It is taken by value, so the address
//                  updates below are local to this call.
//                  NOTE(review): the final update has no effect for the
//                  caller — confirm whether `res` was meant to be a reference.
//   offset       : per-lane offset forwarded unchanged to every load.
template <int N, int smem_offset, int load_count, int i_sstride, int i_gstride, int k_gstride>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, int32x4_t res, int offset) {
  static_assert(N == 16,
                "cp_async_gs: only 16-byte (dwordx4) transfers are implemented");
  static_assert(load_count >= 1, "cp_async_gs: load_count must be at least 1");

  if constexpr (load_count == 1) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride);
  } else if constexpr (load_count == 2) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + i_sstride>(lds_base_ptr, res,
                                                         offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - i_gstride);
  } else if constexpr (load_count == 4) {
    async_buffer_load_dwordx4_v<smem_offset>(lds_base_ptr, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + i_sstride>(lds_base_ptr, res,
                                                         offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + 2 * i_sstride>(lds_base_ptr, res,
                                                             offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    async_buffer_load_dwordx4_v<smem_offset + 3 * i_sstride>(lds_base_ptr, res,
                                                             offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - 3 * i_gstride);
  } else {
    // Pointer arithmetic needs a complete type; step in bytes through char*.
    char *const lds_bytes = static_cast<char *>(lds_base_ptr);
#pragma unroll
    for (int i = 0; i < load_count - 1; ++i) {
      async_buffer_load_dwordx4_v<smem_offset>(lds_bytes + i * i_sstride, res,
                                               offset);
      UPDATE_WAVE_BUFFER_RESOURCE(res, i_gstride);
    }
    async_buffer_load_dwordx4_v<smem_offset>(
        lds_bytes + (load_count - 1) * i_sstride, res, offset);
    UPDATE_WAVE_BUFFER_RESOURCE(res, k_gstride - (load_count - 1) * i_gstride);
  }
}

qisan's avatar
qisan committed
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

// Disabled legacy overload, kept for reference only. Its `template <int N>`
// header is commented out together with the body so that it cannot attach to
// whatever declaration follows it in the file.
// template <int N>
// TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
//   if constexpr (N == 16) {
//     *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
//   } else if constexpr (N == 8) {
//     *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
//   } else if constexpr (N == 4) {
//     async_buffer_load_dword_v(
//         lds_base_ptr,
//         make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
//         threadIdx.x * N /*assume 4 bytes*/);
//   }
// }

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
// Reads a float4_ worth of data into `dst` with inline-asm ds_read
// instructions, selected at compile time by the (M, N) pair.
//
// Template parameters:
//   M, N   : shape selector; only (16, 32) and (32, 16) emit code. Any other
//            combination compiles to an empty body and leaves `dst` untouched.
//   offset : offset in half_t elements; converted to bytes below and passed
//            as an immediate ("n") to the instruction's `offset:` field.
//
// Parameters:
//   dst          : destination, bound as a read-write ("+v") vector operand.
//   lds_base_ptr : 32-bit address value passed in a vector register.
template <int M, int N, int offset>
TL_DEVICE void ds_read_vector(float4_& dst, uint32_t  lds_base_ptr)
{
  if constexpr (M == 16 && N == 32)
  {
    // Single instruction fills the whole of dst.
    const int offset_in_bytes = offset * sizeof(half_t);
    asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
                 : "+v"(dst)
                 : "v"(lds_base_ptr),
                  "n"(offset_in_bytes)
                 : "memory");
  }
  else if constexpr (M == 32 && N == 16)
  {
    // Two 64-bit reads: the first half of dst comes from `offset`, the second
    // half from 4096 bytes further on.
    const int offset_in_bytes0 = offset * sizeof(half_t);
    const int offset_in_bytes1 = offset_in_bytes0 + 4096;
    float2_& front = *reinterpret_cast<float2_*>(&dst);
    float2_& rear  = *(reinterpret_cast<float2_*>(&dst) + 1);

    // Operand map: %0 = rear, %1 = front, %2 = address, %3/%4 = byte offsets.
    asm volatile(
      "ds_read_b64 %1, %2 offset:%3\n\t"
      "ds_read_b64 %0, %2 offset:%4\n\t"

      : "+v"(rear), "+v"(front)
      : "v"(lds_base_ptr), "n"(offset_in_bytes0), "n"(offset_in_bytes1)
      : "memory"
    );
  }
}

Lukinon's avatar
Lukinon committed
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copies N bytes from `global_base_ptr` to `lds_base_ptr` when `cond` is
// true, and zero-fills the destination otherwise.
//
//   N == 16 : plain uint4 load/store (or a uint4 of zeros).
//   N == 8  : plain uint2 load/store (or a uint2 of zeros).
//   other   : treated as a 4-byte element. When `cond` is true one
//             asynchronous dword buffer load is issued via
//             async_buffer_load_dword_v; when false a single 32-bit zero is
//             stored, matching the one dword the load path moves. (The
//             earlier code stored a 16-byte uint4 here, which wrote 12 bytes
//             past this element and required 16-byte alignment.)
//
// Parameters:
//   lds_base_ptr    : destination pointer.
//   global_base_ptr : source pointer; only dereferenced when `cond` is true.
//   cond            : selects copy (true) or zero-fill (false).
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr =
        cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr =
        cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
  } else {
    if (cond) {
      // The resource base is rewound by threadIdx.x elements and the same
      // amount is re-applied through the per-lane offset.
      async_buffer_load_dword_v(
          lds_base_ptr,
          make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
          threadIdx.x * N /*assume 4 bytes*/);
    } else {
      // Zero exactly one dword so neighbouring elements are left intact.
      *static_cast<uint32_t *>(lds_base_ptr) = 0u;
    }
  }
}

230
231


Lukinon's avatar
Lukinon committed
232
} // namespace tl