Unverified Commit 8182976c authored by carlushuang, committed by GitHub

[CK_TILE] workaround precision issue, remove sgpr offset for inline asm (#1356)



* workaround precision issue, remove sgpr offset for inline asm

* macro for set tile

* ignore unused param if no kernel instances in host API

* fix more precision issues

* cache buffer resource

* fix

* support pre-nop

* clear tile by vector type members

* add workaround to reduce scratch memory

* conditionally enable workaround code

* enable workaround starting from a certain build version

* use fallback set_tile() implementation starting from a certain build version

* undo template argument changes

* put dummy asm in load_raw()

* fix comments, refactor s_nop inside buffer_load

---------
Co-authored-by: PoYen, Chen <PoYen.Chen@amd.com>
parent eb44e047
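
The recurring change in the inline-asm helpers below is removing the scalar-register (SGPR) wave offset operand and hard-coding the buffer instruction's soffset field to the literal 0. A minimal before/after sketch of that operand change follows; it is HIP device code for gfx9-class targets with simplified stand-in types, not the verbatim library code.

// Hedged sketch of the sgpr-offset removal; fp32x4_t / int32x4_t mimic the CK_TILE vector types.
using int32x4_t = int __attribute__((ext_vector_type(4)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));

// before: the wave scalar offset travels through a separate "s" operand (%3)
__device__ void load_dwordx4_old(fp32x4_t& value, int32x4_t res, int v_offset, int s_offset)
{
    asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
                 : "+v"(value)
                 : "v"(v_offset), "s"(res), "s"(s_offset), "n"(0)
                 : "memory");
}

// after: soffset is the literal 0, so each access carries one fewer SGPR operand
__device__ void load_dwordx4_new(fp32x4_t& value, int32x4_t res, int v_offset)
{
    asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
                 : "+v"(value)
                 : "v"(v_offset), "s"(res), "n"(0)
                 : "memory");
}
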
...@@ -271,7 +271,9 @@ class FmhaBwdApiPool: ...@@ -271,7 +271,9 @@ class FmhaBwdApiPool:
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
# GEMM0: Q@K=S^T # GEMM0: Q@K=S^T
......
...@@ -278,6 +278,9 @@ class FmhaFwdApiPool: ...@@ -278,6 +278,9 @@ class FmhaFwdApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
...@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool: ...@@ -331,6 +331,9 @@ class FmhaFwdSplitKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# if the dispatch string is empty, add (void) casts to suppress unused-parameter warnings in the generated API
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
@dataclass @dataclass
......
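
All three generator classes (bwd, fwd, splitkv) gain the same guard: when the trait filter leaves no kernel instances, the per-dtype dispatch string stays empty and the generated host API would otherwise trip -Wunused-parameter. A sketch of what the emitted C++ then reduces to is shown below; the function and parameter names are illustrative placeholders, only the (void) casts come from this change.

#include <iostream>

// Hypothetical shape of a generated dispatch function with zero kernel instances.
float fmha_fwd_dispatch(int t /*traits*/, int s /*stream config*/, int a /*args*/)
{
    // F_dispatch expanded to nothing, so suppress the unused-parameter warnings
    (void)t; (void)s; (void)a;
    return -1.f; // illustrative "nothing was dispatched" result
}

int main()
{
    std::cout << fmha_fwd_dispatch(0, 0, 0) << "\n";
    return 0;
}
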
...@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa ...@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
} // namespace impl } // namespace impl
// TODO: glc/slc/... // TODO: glc/slc/...
template <index_t bytes> template <index_t bytes, bool pre_nop = false>
struct buffer_load; struct buffer_load;
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
// (exp_vector_type(xxx)) // (exp_vector_type(xxx))
template <> template <bool pre_nop>
struct buffer_load<16> struct buffer_load<16, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<8> struct buffer_load<8, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<4> struct buffer_load<4, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<2> struct buffer_load<2, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load<1> struct buffer_load<1, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0) index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" if constexpr(pre_nop)
: "+v"(reinterpret_cast<mbuf_t&>(value)) asm volatile("s_nop 4\n"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "memory"); : "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
} }
}; };
template <index_t bytes> template <index_t bytes, bool pre_nop = false>
struct buffer_load_if; struct buffer_load_if;
template <> template <bool pre_nop>
struct buffer_load_if<16> struct buffer_load_if<16, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T)); static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<8> struct buffer_load_if<8, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<4> struct buffer_load_if<4, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<2> struct buffer_load_if<2, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
template <> template <bool pre_nop>
struct buffer_load_if<1> struct buffer_load_if<1, pre_nop>
{ {
template <typename T> template <typename T>
CK_TILE_DEVICE void operator()(T& value, CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec(); auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile( if constexpr(pre_nop)
"v_cmpx_le_u32 exec, 1, %5\n" asm volatile("s_nop 4\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" "v_cmpx_le_u32 exec, 1, %4\n"
"s_mov_b64 exec %6" "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
: "+v"(reinterpret_cast<mbuf_t&>(value)) "s_mov_b64 exec %5"
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) : "+v"(reinterpret_cast<mbuf_t&>(value))
: "memory"); : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
} }
}; };
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
...@@ -294,17 +379,16 @@ struct buffer_store<16> ...@@ -294,17 +379,16 @@ struct buffer_store<16>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t; using mbuf_t = fp32x4_t;
asm volatile( asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -315,17 +399,16 @@ struct buffer_store<8> ...@@ -315,17 +399,16 @@ struct buffer_store<8>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 8); static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t; using mbuf_t = fp32x2_t;
asm volatile( asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -336,17 +419,16 @@ struct buffer_store<4> ...@@ -336,17 +419,16 @@ struct buffer_store<4>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = float;
asm volatile( asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -357,17 +439,16 @@ struct buffer_store<2> ...@@ -357,17 +439,16 @@ struct buffer_store<2>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 2); static_assert(sizeof(T) == 2);
using mbuf_t = short; using mbuf_t = short;
asm volatile( asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
"buffer_store_short %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -378,17 +459,16 @@ struct buffer_store<1> ...@@ -378,17 +459,16 @@ struct buffer_store<1>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1) index_t /*flag*/ = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
using mbuf_t = float; using mbuf_t = float;
asm volatile( asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4" :
: : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) : "memory");
: "memory");
} }
}; };
...@@ -402,21 +482,20 @@ struct buffer_store_if<16> ...@@ -402,21 +482,20 @@ struct buffer_store_if<16>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 16); static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t; using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -431,7 +510,7 @@ struct buffer_store_if<8> ...@@ -431,7 +510,7 @@ struct buffer_store_if<8>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
...@@ -439,14 +518,13 @@ struct buffer_store_if<8> ...@@ -439,14 +518,13 @@ struct buffer_store_if<8>
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>; using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -461,21 +539,20 @@ struct buffer_store_if<4> ...@@ -461,21 +539,20 @@ struct buffer_store_if<4>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -490,21 +567,20 @@ struct buffer_store_if<2> ...@@ -490,21 +567,20 @@ struct buffer_store_if<2>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 2); static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short; using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -519,21 +595,20 @@ struct buffer_store_if<1> ...@@ -519,21 +595,20 @@ struct buffer_store_if<1>
CK_TILE_DEVICE void operator()(const T& value, CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/, int32x4_t res /*buffer resource*/,
index_t v_offset, index_t v_offset,
index_t s_offset, index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/, index_t i_offset /*max 0xFFF*/,
index_t flag = 1) index_t flag = 1)
{ {
static_assert(sizeof(T) == 4); static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec(); auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float; using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n" asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %6" "s_mov_b64 exec %5"
: :
: "v"(bit_cast<mbuf_t>(value)), : "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset), "v"(v_offset),
"s"(res), "s"(res),
"s"(s_offset),
"n"(i_offset), "n"(i_offset),
"v"(flag), "v"(flag),
"s"(save_exec) "s"(save_exec)
...@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, ...@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem, template <bool pre_nop = false>
int32x4_t rsrc, CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
index_t voffset, int32x4_t rsrc,
index_t soffset, index_t voffset,
index_t ioffset /*max 0xFFF*/, index_t /*soffset*/,
index_t /*flag*/ = 0) index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{ {
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" if constexpr(pre_nop)
: "=r"(smem) /*dummy dependency for smem*/ asm volatile("s_nop 4\n"
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "memory"); : "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
} }
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
...@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe ...@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource, int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset, index_t src_thread_addr_offset,
index_t src_wave_addr_offset, index_t src_wave_addr_offset,
index_t flag = 0) index_t flag = 0,
bool_constant<pre_nop> = {})
{ {
constexpr index_t bytes = sizeof(T) * N; constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
...@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, ...@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
using type = thread_buffer<T, N>; using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check) if constexpr(oob_conditional_check)
{ {
buffer_load_if<sizeof(type)>{}( buffer_load_if<sizeof(type), pre_nop>{}(dst,
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
} }
else else
{ {
buffer_load<sizeof(type)>{}( buffer_load<sizeof(type), pre_nop>{}(dst,
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
} }
} }
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource, int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset, index_t src_thread_addr_offset,
index_t src_wave_addr_offset, index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0) index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{ {
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem, async_buffer_load_dword_v(smem,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset, src_wave_addr_offset,
src_immediate_addr_offset); src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
} }
template <index_t N, template <index_t N,
...@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave, const T* p_src_wave,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_element_space_size, index_t src_element_space_size,
index_t is_valid_element = 0) index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>( amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// This version supports taking a buffer resource as the input argument
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
} }
// unfortunately async copy can not make sure invalid data is zero inside LDS // unfortunately async copy can not make sure invalid data is zero inside LDS
...@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, ...@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working. // buffer_load OOB still working.
template <typename T, template <typename T,
index_t N, index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, bool pre_nop = false>
const T* p_src_wave, CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
index_t src_thread_element_offset, const T* p_src_wave,
index_t src_element_space_size) index_t src_thread_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{ {
const int32x4_t src_wave_buffer_resource = const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
...@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, ...@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>( amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// This version supports taking a buffer resource as the input argument
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
} }
// buffer_store requires: // buffer_store requires:
......
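
Every buffer_load / buffer_load_if specialization in this file now carries a bool pre_nop template parameter plus a bool_constant<pre_nop> tag argument, and picks between two asm strings with if constexpr (the pre_nop variant prepends "s_nop 4"). Below is a host-compilable sketch of just that dispatch pattern, with the real GCN asm replaced by prints; bool_constant is assumed to behave like std::integral_constant.

#include <cstdio>
#include <type_traits>

// Stand-in for ck_tile::bool_constant (assumed to be an integral_constant-like tag).
template <bool B>
using bool_constant = std::integral_constant<bool, B>;

template <int bytes, bool pre_nop = false>
struct buffer_load_sketch
{
    template <typename T>
    void operator()(T& value, int v_offset, bool_constant<pre_nop> = {}) const
    {
        if constexpr(pre_nop)
        {
            // library code: asm volatile("s_nop 4\n" "buffer_load_dwordx4 ..., 0 offen ...")
            std::printf("s_nop 4; %d-byte buffer load at voffset %d\n", bytes, v_offset);
        }
        else
        {
            // library code: asm volatile("buffer_load_dwordx4 ..., 0 offen ...")
            std::printf("%d-byte buffer load at voffset %d\n", bytes, v_offset);
        }
        value = T{}; // placeholder for the data the real asm would return
    }
};

int main()
{
    float v = 0.f;
    buffer_load_sketch<16, true>{}(v, 64, bool_constant<true>{}); // pre-nop path
    buffer_load_sketch<16>{}(v, 64);                              // default path
    return 0;
}
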
...@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() ...@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::); " ::);
} }
CK_TILE_DEVICE void s_nop() CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{ {
#if 1 #if 1
asm volatile("\ asm volatile("s_nop %0" : : "n"(cnt) :);
s_nop 0 \n \
" ::);
#else #else
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(cnt);
#endif #endif
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#define __gfx12__ #define __gfx12__
#endif #endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
...@@ -147,6 +148,14 @@ ...@@ -147,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG #ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0 #define CK_TILE_DEBUG_LOG 0
#endif #endif
......
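
The new config switch ties the scratch-memory workaround to the HIP runtime version reported by hip/hip_version.h, while the outer #ifndef keeps it overridable from the compiler command line. A standalone sketch of the same gating pattern follows; the __has_include fallback and its version numbers exist only so the sketch builds on machines without the HIP headers.

#include <cstdio>

// Mirror of the version gate in this diff.
#if __has_include(<hip/hip_version.h>)
#include <hip/hip_version.h>
#else
// arbitrary fallback values, only for building this sketch without ROCm installed
#define HIP_VERSION_MAJOR 6
#define HIP_VERSION_MINOR 1
#define HIP_VERSION_PATCH 40091
#endif

#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif

int main()
{
    // can still be forced either way, e.g. -DCK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE=0
    std::printf("scratch-memory workaround enabled: %d\n",
                CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE);
    return 0;
}
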
...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic, ...@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::generic; return address_space_enum::generic;
...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global, ...@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr; T* p_data_ = nullptr;
BufferSizeType buffer_size_; BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0}; remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view() CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{} : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{ {
} }
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size, BufferSizeType buffer_size,
T invalid_element_value) T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} : p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{ {
} }
// this is intentionally non-constexpr (it calls an intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::global; return address_space_enum::global;
...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global, ...@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{ {
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global, ...@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>( amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, p_data_, i, buffer_size_, is_valid_element); dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value, typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false> bool>::type = false>
CK_TILE_DEVICE constexpr auto CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{ {
// X is vector of T // X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global, ...@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>( amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_); smem, cached_buf_res_, i, bool_constant<pre_nop>{});
} }
// i is offset of T, not X. i should be aligned to X // i is offset of T, not X. i should be aligned to X
...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds, ...@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::lds; return address_space_enum::lds;
...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr, ...@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space() CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{ {
return address_space_enum::vgpr; return address_space_enum::vgpr;
......
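
The global-memory buffer_view now stores the int32x4_t buffer resource in cached_buf_res_ and exposes init_raw() to fill it, while the generic, lds, and vgpr views get empty init_raw() stubs so callers can invoke it unconditionally. A simplified host-side sketch of the caching idea follows; it is not the real class, and make_wave_buffer_resource is mimicked with a plain array.

#include <array>
#include <cstdio>

// Simplified stand-in for the global buffer_view: the resource descriptor is built
// once in init_raw() and reused by every *_raw access, instead of being rebuilt
// from (pointer, size) on each call.
struct buffer_view_sketch
{
    const float* p_data_ = nullptr;
    int buffer_size_     = 0;
    std::array<long long, 2> cached_buf_res_{}; // stand-in for the int32x4_t descriptor

    void init_raw()
    {
        // library code: cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type))
        cached_buf_res_ = {reinterpret_cast<long long>(p_data_),
                           static_cast<long long>(buffer_size_ * sizeof(float))};
    }

    float get_raw(int i) const
    {
        // library code: amd_buffer_load_raw<...>(dst, cached_buf_res_, i, is_valid, ...)
        return reinterpret_cast<const float*>(cached_buf_res_[0])[i];
    }
};

int main()
{
    float data[4] = {1.f, 2.f, 3.f, 4.f};
    buffer_view_sketch view{data, 4};
    view.init_raw(); // must run before any *_raw access, as the diff comment notes
    std::printf("%f\n", view.get_raw(2));
    return 0;
}
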
...@@ -36,30 +36,37 @@ template <typename T, ...@@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord, index_t NumCoord,
bool oob_conditional_check = true> bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile, CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window, NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {}) bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}); tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
template <typename LdsTileWindow_, template <typename LdsTileWindow_,
typename BottomTensorView_, typename BottomTensorView_,
typename WindowLengths_, typename WindowLengths_,
typename TileDistribution_, typename TileDistribution_,
index_t NumCoord> index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile, async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_, const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_, WindowLengths_,
TileDistribution_, TileDistribution_,
NumCoord>& tile_window) NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{ {
return tile_window.async_load(lds_tile); return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
} }
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
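
The free functions now forward both tags: load_tile_raw(tile, window, oob_check, pre_nop) and async_load_tile_raw(lds_tile, window, oob_check, pre_nop), the latter calling the window's renamed async_load_raw(). A hedged call-site fragment is sketched below; the tile/window objects are assumed to come from the usual CK_TILE pipeline setup (not shown), and only the call signatures are taken from this diff.

// Assumed context: a_reg is a static distributed register tile, a_window / b_window
// are tile windows over global memory, b_lds is an LDS tile window.
template <typename RegTile, typename AWindow, typename LdsWindow, typename BWindow>
CK_TILE_DEVICE void prefetch_step(RegTile& a_reg,
                                  const AWindow& a_window,
                                  LdsWindow& b_lds,
                                  const BWindow& b_window)
{
    // pre_nop = true: the very first raw access emits an s_nop before its buffer_load
    load_tile_raw(a_reg, a_window, bool_constant<true>{}, bool_constant<true>{});

    // async (direct-to-LDS) path; defaults keep the OOB check on and skip the pre-nop
    async_load_tile_raw(b_lds, b_window);

    async_load_fence(); // wait for the outstanding async loads
}
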
...@@ -35,6 +35,8 @@ struct null_tile_window ...@@ -35,6 +35,8 @@ struct null_tile_window
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
CK_TILE_DEVICE void init_raw() {}
WindowLengths window_lengths_; WindowLengths window_lengths_;
}; };
......
...@@ -36,6 +36,8 @@ struct tensor_view ...@@ -36,6 +36,8 @@ struct tensor_view
{ {
} }
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
...@@ -85,30 +87,34 @@ struct tensor_view ...@@ -85,30 +87,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE void CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
get_vectorized_elements_raw(remove_cvref_t<X>& dst, const TensorCoord& coord,
const TensorCoord& coord, bool_constant<oob_conditional_check> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<pre_nop> = {}) const
{ {
return buf_.template get_raw<X, oob_conditional_check>( return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, dst,
coord.get_offset(), coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
} }
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem, CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
const TensorCoord& coord) const remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{ {
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/); return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
} }
// X is vector of DataType. // X is vector of DataType.
......
...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) ...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor... // sub-dword tensor...
template <typename DstrTensors, index_t v> template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>) CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{ {
constexpr index_t tensor_bytes = using elem_type = typename DstrTensors::DataType;
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); constexpr index_t elem_size = sizeof(elem_type);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
// # bytes per write = 4
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
{ {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
auto& buffer = dstr_tensor.get_thread_buffer();
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
if constexpr(elem_size == 1)
{
// # elements per write = 4
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
buffer[i_write * 4 + 0] = values.x;
buffer[i_write * 4 + 1] = values.y;
buffer[i_write * 4 + 2] = values.z;
buffer[i_write * 4 + 3] = values.w;
}
else if constexpr(elem_size == 2)
{
// # elements per write = 2
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
buffer[i_write * 2 + 0] = values.x;
buffer[i_write * 2 + 1] = values.y;
}
else if constexpr(elem_size == 4)
{
// # elements per write = 1
constexpr elem_type value = 0;
buffer[i_write] = value;
}
else
{
static_assert(false, "type not supported");
}
});
#else
using dvec_t = array<index_t, tensor_bytes / 4>; using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer()); auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++) for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v; tensor.get(i) = v;
#endif
} }
else else
{ {
tile_elementwise_inout( tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); }, dstr_tensor);
dstr_tensor);
} }
} }
......
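
When the ROCm 6.1 scratch-memory workaround is active, set_tile(dstr_tensor, number<0>{}) stops reinterpreting the thread buffer as an array of dwords and instead clears it element by element, grouped so each static_for step still covers four bytes. A standalone sketch of that grouping over a plain array is shown below; the element arrays stand in for the CK_TILE thread buffer.

#include <cstdint>
#include <cstdio>

// Sketch of the workaround's per-element clear: one loop iteration per 4 bytes,
// with 4/2/1 element writes per group depending on the element size, instead of
// one reinterpret_cast'ed dword store.
template <typename Elem, int N>
void clear_like_workaround(Elem (&buf)[N])
{
    static_assert((N * sizeof(Elem)) % 4 == 0, "thread buffer must be a dword multiple");
    constexpr int writes    = (N * sizeof(Elem)) / 4; // number of 4-byte groups
    constexpr int per_write = 4 / sizeof(Elem);       // elements covered per group
    for(int w = 0; w < writes; ++w)
        for(int e = 0; e < per_write; ++e)
            buf[w * per_write + e] = Elem{0};
}

int main()
{
    std::uint8_t  i8[8]  = {1, 2, 3, 4, 5, 6, 7, 8}; // elem_size == 1: 4 elements per group
    std::uint16_t f16[4] = {1, 2, 3, 4};             // elem_size == 2: 2 elements per group
    float         f32[2] = {1.f, 2.f};               // elem_size == 4: 1 element per group

    clear_like_workaround(i8);
    clear_like_workaround(f16);
    clear_like_workaround(f32);

    std::printf("%d %d %f\n", (int)i8[0], (int)f16[0], (double)f32[0]); // all zero now
    return 0;
}
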
...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution ...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true> template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution ...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
...@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution ...@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(), dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord, bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{},
pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution ...@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
} }
}); });
}); });
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
} }
// TODO: currently async load only implemented in inline asm // TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
...@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
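set_window_origin() re-anchors an existing window and refreshes the pre-computed coordinate bundles, which avoids rebuilding the whole window (and its distribution) on every loop iteration. A simplified, standalone illustration of that caching pattern is below; Window2D and its members are hypothetical, not ck_tile types.

#include <array>
#include <cstddef>

// A window caches per-access offsets derived from its origin, so changing
// the origin must also refresh the cache instead of constructing a new
// window from scratch.
struct Window2D
{
    std::array<std::ptrdiff_t, 2> origin{};
    std::ptrdiff_t row_stride = 0;
    std::array<std::ptrdiff_t, 4> cached_offsets{}; // analogue of pre_computed_coords_

    void recompute()
    {
        for(std::size_t i = 0; i < cached_offsets.size(); ++i)
            cached_offsets[i] =
                (origin[0] + static_cast<std::ptrdiff_t>(i)) * row_stride + origin[1];
    }

    void set_window_origin(std::array<std::ptrdiff_t, 2> new_origin)
    {
        origin = new_origin;
        recompute(); // keep cached coordinates consistent with the new origin
    }
};

The pipeline change further down uses exactly this entry point: k_dram_window.set_window_origin(...) replaces the previous make_tile_window(...) reconstruction.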
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
...
...@@ -81,6 +81,12 @@ struct BlockFmhaPipelineQRKSVSAsync
return Problem::kBlockPerCu;
else
{
// minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)
{
return 1;
}
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
...@@ -220,6 +226,7 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well
...@@ -293,6 +300,17 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK &&
(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
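k_oob_ck and k_pre_np are compile-time flags carried as bool_constant values, so the load path can branch with if constexpr at zero runtime cost. A rough standalone sketch of this flag-forwarding style (placeholder names and a simplified condition, not the actual ck_tile API) is:

#include <type_traits>

// Placeholder traits; the real condition also distinguishes bias kinds.
template <bool PadSeqLenK, bool HasBias, bool HasDropout>
struct KLoadFlags
{
    using oob_check = std::true_type; // always range-check K loads here
    using pre_nop   = std::bool_constant<PadSeqLenK && (HasBias || HasDropout)>;
};

// The callee receives the flags as types and specializes at compile time.
template <typename OobCheck, typename PreNop>
void load_k_tile(OobCheck, PreNop)
{
    if constexpr(PreNop::value)
    {
        // e.g. emit a leading nop/wait before the first load
    }
    // ... issue the (optionally bounds-checked) loads ...
}

// usage: load_k_tile(typename KLoadFlags<true, true, false>::oob_check{},
//                    typename KLoadFlags<true, true, false>::pre_nop{});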
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -310,7 +328,7 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
...@@ -333,7 +351,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
...@@ -637,16 +657,13 @@ struct BlockFmhaPipelineQRKSVSAsync
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
...