// copy.h
#pragma once

#include "common.h"
#include <cstdint> // std::uint8_t / uint16_t / uint32_t, uintptr_t

using f32 = float;
// using f16 = _Float16;

using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;

using index_t = u32;

using ck_tile::int32x4_t;

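// 128-bit buffer resource descriptor (V#) consumed by buffer_load /
// buffer_store: a 64-bit base address, the byte range of the buffer, and the
// architecture-specific config dword.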
struct __attribute__((packed)) buffer_resource {
  const void *ptr;
  uint32_t range;
  uint32_t config;
};

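// Builds a wave-uniform buffer resource for ptr. readfirstlane broadcasts
// each dword so the descriptor ends up in SGPRs, as the "s" constraint of
// the buffer_load inline asm below requires.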
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
                                                   uint32_t size = 0xffffffff) {
  buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
  int32x4_t r = __builtin_bit_cast(int32x4_t, res);
  r.x = __builtin_amdgcn_readfirstlane(r.x);
  r.y = __builtin_amdgcn_readfirstlane(r.y);
  r.z = __builtin_amdgcn_readfirstlane(r.z);
  r.w = __builtin_amdgcn_readfirstlane(r.w);
  return r;
}

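// m0 supplies the LDS base address consumed by "buffer_load ... lds"; the
// two helpers below set and advance it between staged copies.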
__device__ void init_m0(uint32_t m0_value) {
  asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory");
}

__device__ void inc_m0(uint32_t m0_inc) {
  // "s": the increment lives in an SGPR, so m0_inc may be a runtime value.
  asm volatile("s_add_u32 m0, %0, m0" : : "s"(m0_inc) : "memory");
}

namespace tl {

// On AMDGPU, async copies are committed implicitly, so this is a no-op.
TL_DEVICE void cp_async_commit() {}

// Global memory fence: wait until at most cnt vector-memory operations are
// outstanding (cnt must be a compile-time constant for the "n" constraint).
__device__ void async_gld_fence(index_t cnt) {
  asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// LDS/scalar memory fence: wait until at most cnt lgkm (LDS, GDS, scalar,
// message) operations are outstanding.
__device__ void async_gld_sld_fence(index_t cnt) {
  asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}

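// Note: s_barrier synchronizes all waves of the workgroup, not just one wave.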
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }

template <int N = 0> TL_DEVICE void cp_async_wait() {
  async_gld_fence(N);
  // or, to additionally wait on LDS traffic:
  // async_gld_sld_fence(N);
}

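// Loads one dword per lane from global memory directly into LDS: m0 is set
// to the wave-uniform LDS destination and buffer_load_dword is issued with
// the "lds" modifier, bypassing VGPRs. The pre_nop parameter is currently
// unused.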
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
                                              index_t voffset) {
  auto const lds_ptr_sgpr =
      __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
  asm volatile("s_mov_b32 m0, %0; \n\t"
               "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
               "v"(voffset), "s"(rsrc)
               : "memory");
}

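// Copies N bytes per lane from global memory to LDS. N == 16 and N == 8 go
// through registers as plain vector loads/stores; N == 4 takes the
// direct-to-LDS path. Rewinding the resource base by threadIdx.x dwords makes
// it wave-uniform (assuming each lane passes base + threadIdx.x dwords),
// while voffset = threadIdx.x * 4 re-applies the per-lane offset.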
template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
  } else if constexpr (N == 4) {
    async_buffer_load_dword_v(
        lds_base_ptr,
        make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
        threadIdx.x * N /*assume 4 bytes*/);
  }
}

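// Predicated variant of cp_async_gs: lanes with cond == false write zeros to
// LDS instead of loading from global memory.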
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr =
        cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr =
        cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
  } else {
    if (cond) {
      async_buffer_load_dword_v(
          lds_base_ptr,
          make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
          threadIdx.x * N /*assume 4 bytes*/);
    } else {
      // Zero-fill exactly the N == 4 bytes this path copies (one dword,
      // matching the buffer_load_dword above).
      *(uint32_t *)lds_base_ptr = 0u;
    }
  }
}

} // namespace tl
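
// Usage sketch (illustrative only; the kernel below is not part of this
// header, and its name and sizes are made up; assumes a 256-thread launch).
// Each lane stages one dword into LDS, the outstanding loads are drained,
// and the workgroup synchronizes:
//
//   __global__ void stage_tile(int32_t *gmem) {
//     __shared__ int32_t tile[256];
//     tl::cp_async_gs<4>(&tile[threadIdx.x], &gmem[threadIdx.x]);
//     tl::cp_async_commit();  // no-op on AMDGPU
//     tl::cp_async_wait<0>(); // s_waitcnt vmcnt(0)
//     tl::wave_barrier();     // s_barrier
//     // ... compute on tile ...
//   }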