// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include "common.h"

using f32 = float;
// using f16 = _Float16;

using u8 = std::uint8_t;
using u16 = std::uint16_t;
using u32 = std::uint32_t;

using index_t = u32;

using ck_tile::int32x4_t;

// 128-bit buffer resource descriptor (SRD) as consumed by MUBUF instructions.
struct __attribute__((packed)) buffer_resource {
  const void *ptr; // dwords 0-1: 64-bit base address
  uint32_t range;  // dword 2: buffer size in bytes
  uint32_t config; // dword 3: format/stride/swizzle flags
};

// Build a wave-uniform 128-bit buffer resource descriptor for `ptr`. The
// default range of 0xffffffff effectively disables bounds checking.
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
                                                   uint32_t size = 0xffffffff) {
  buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
  int32x4_t r = __builtin_bit_cast(int32x4_t, res);
  // Buffer instructions expect the descriptor in SGPRs; readfirstlane
  // broadcasts lane 0's value so the compiler keeps it scalar.
  r.x = __builtin_amdgcn_readfirstlane(r.x);
  r.y = __builtin_amdgcn_readfirstlane(r.y);
  r.z = __builtin_amdgcn_readfirstlane(r.z);
  r.w = __builtin_amdgcn_readfirstlane(r.w);
  return r;
}
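
// Usage sketch (illustrative only; `src_tile` and `tile_bytes` are
// hypothetical names): build one descriptor per wave and let every lane
// address into it with its own byte offset.
//
//   int32x4_t rsrc = make_wave_buffer_resource(src_tile, tile_bytes);
//   // lane i then reads byte offset i * sizeof(float) via buffer loads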

// Set the m0 register, which holds the LDS destination base address (in
// bytes) for direct-to-LDS buffer loads. `m0_value` must be wave-uniform.
__device__ void init_m0(uint32_t m0_value) {
  asm volatile("s_mov_b32 m0, %0" : : "s"(m0_value) : "memory");
}

// Advance m0; the "n" constraint requires `m0_inc` to fold to a
// compile-time immediate at the call site.
__device__ void inc_m0(uint32_t m0_inc) {
  asm volatile("s_add_u32 m0, %0, m0" : : "n"(m0_inc) : "memory");
}
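
// Typical pattern (a sketch; the offsets are hypothetical): point m0 at the
// LDS destination once, then bump it between back-to-back direct-to-LDS
// loads so each load lands one tile further into shared memory.
//
//   init_m0(lds_byte_offset); // lds_byte_offset must be wave-uniform
//   /* buffer_load_dword ... offen lds */
//   inc_m0(256);              // advance the LDS destination by 256 bytes
//   /* buffer_load_dword ... offen lds */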

namespace tl {

// AMDGPU tracks outstanding loads in hardware wait counters, so there is no
// explicit commit instruction; this is a no-op kept for API symmetry.
TL_DEVICE void cp_async_commit() {}

// Global-memory-only fence: waits until at most `cnt` vector-memory
// operations are still outstanding. `cnt` must fold to a constant ("n").
__device__ void async_gld_fence(index_t cnt) {
  asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// Shared/scalar-memory fence: waits until at most `cnt` LDS, GDS, constant,
// or message operations are still outstanding. Note this waits on lgkmcnt
// only; pair with async_gld_fence to also cover global (vector) memory.
__device__ void async_gld_sld_fence(index_t cnt) {
  asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}

// Barrier across all waves in the workgroup.
__device__ void wave_barrier() { asm volatile("s_barrier" : : : "memory"); }

// Wait until at most N outstanding async (direct-to-LDS) loads remain.
template <int N = 0> TL_DEVICE void cp_async_wait() {
  async_gld_fence(N);
  // Direct-to-LDS loads count against vmcnt; if LDS traffic must also be
  // drained, use async_gld_sld_fence(N) instead.
}

// Issue a direct-to-LDS load of one dword per lane: set m0 to the
// (wave-uniform) LDS destination, then issue buffer_load_dword with the
// `lds` modifier so the data bypasses VGPRs. `pre_nop` is currently unused.
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
                                              index_t voffset) {
  // m0 holds the LDS base address; readfirstlane makes it wave-uniform.
  auto const lds_ptr_sgpr =
      __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(smem)));
  asm volatile("s_mov_b32 m0, %0; \n\t"
               "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
               "v"(voffset), "s"(rsrc)
               : "memory");
}
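
// Direct-use sketch (hypothetical names `gptr`, `nbytes`, `smem`): each lane
// pulls one dword of `gptr` straight into LDS. `voffset` selects the source
// dword; on GFX9-class hardware lane i's data lands at LDS address
// m0 + 4 * i.
//
//   int32x4_t rsrc = make_wave_buffer_resource(gptr, nbytes);
//   async_buffer_load_dword_v(smem, rsrc, threadIdx.x * sizeof(uint32_t));
//   cp_async_wait<0>(); // drain vmcnt before touching smem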

// Copy N bytes per lane from global to shared memory. The 16- and 8-byte
// paths are plain synchronous vector copies through VGPRs; the 4-byte path
// is an asynchronous direct-to-LDS buffer load.
template <int N>
TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
  } else if constexpr (N == 4) {
    // Rebase the descriptor to the wave's first lane so that
    // voffset = threadIdx.x * 4 addresses this lane's dword.
    async_buffer_load_dword_v(
        lds_base_ptr,
        make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
        threadIdx.x * N /*assume 4 bytes*/);
  }
}
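
// End-to-end usage sketch (hypothetical kernel, not compiled as part of this
// header): issue the async copy, commit (a no-op here), wait for it to land,
// and barrier before any thread reads the tile. Assumes one float per lane
// and blockDim.x <= 256.
#if 0
__global__ void example_copy_tile(const float *gsrc, float *gdst) {
  __shared__ float smem[256];
  cp_async_gs<4>(&smem[threadIdx.x], (void *)(gsrc + threadIdx.x));
  cp_async_commit();  // no-op on AMDGPU
  cp_async_wait<0>(); // s_waitcnt vmcnt(0)
  wave_barrier();     // s_barrier
  gdst[threadIdx.x] = smem[threadIdx.x];
}
#endif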

// Conditional variant: copy N bytes per lane when `cond` is true and
// zero-fill the destination otherwise (e.g. for out-of-bounds lanes).
template <int N>
TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
                                       void *global_base_ptr, bool cond) {
  if constexpr (N == 16) {
    *(uint4 *)lds_base_ptr =
        cond ? *(uint4 *)global_base_ptr : make_uint4(0, 0, 0, 0);
  } else if constexpr (N == 8) {
    *(uint2 *)lds_base_ptr =
        cond ? *(uint2 *)global_base_ptr : make_uint2(0, 0);
  } else {
    if (cond) {
      async_buffer_load_dword_v(
          lds_base_ptr,
          make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
          threadIdx.x * N /*assume 4 bytes*/);
    } else {
      // Zero-fill only the single dword this lane owns; a uint4 store here
      // would write 16 bytes and clobber neighboring lanes' slots.
      *(uint32_t *)lds_base_ptr = 0;
    }
  }
}
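
// Boundary-handling sketch (hypothetical names `valid_elems`, `gsrc`,
// `smem`): lanes beyond the valid range zero-fill their LDS slot instead of
// reading out of bounds.
//
//   bool in_bounds = threadIdx.x < valid_elems;
//   cp_async_gs_conditional<4>(&smem[threadIdx.x],
//                              (void *)(gsrc + threadIdx.x), in_bounds);
//   cp_async_wait<0>();
//   wave_barrier();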

} // namespace tl