Unverified commit 66593407, authored by Po Yen Chen, committed by GitHub

[CK_TILE] Pick bugfixes for ROCm 6.2 compiler issues (#1430)

parent 00626ca8
@@ -355,6 +355,9 @@ class FmhaFwdApiPool:
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        if not per_dtypes:
            # per_dtypes is empty: emit (void) casts so the generated API does not warn about unused arguments
            per_dtypes += ' (void)t ; (void)s ; (void)a;'
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)

@dataclass
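For context, when the kernel filter leaves the pool empty, the generated dispatch body consists of nothing but those (void) casts, so the emitted API still compiles cleanly under -Wunused-parameter. A minimal self-contained sketch of that degenerate output (the function and type names below are placeholders, not the exact generated signature):

// Illustrative stand-ins for the real fmha_fwd_traits / fmha_fwd_args / stream_config types.
struct traits_t {};
struct args_t {};
struct stream_t {};

// Hypothetical shape of the generated API when no kernel instance survives the filter:
// casting each parameter to void marks it as intentionally unused and silences the warning.
float fmha_fwd_dispatch_stub(traits_t t, args_t a, const stream_t& s)
{
    float r = -1;
    (void)t; (void)s; (void)a;
    return r;
}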
@@ -489,7 +492,8 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
                # if hdim == 32, fall back to the 'qr' pipeline to work around a ROCm 6.2 compiler problem (missing s_waitcnt)
                if hdim == 256 or hdim == 32:
                # if True:
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
@@ -497,11 +501,18 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
                    pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                    pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                else:
                    if bias == "bias":
                        # TODO: ROCm 6.2 compiler problem when using the 'qr_async' pipeline for the bias case
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                    else:
                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                    if receipt == 1 and bias != "bias":
                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitrary hdim
                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitrary hdim
        elif dtype in ['fp8', 'bf8']:
...
@@ -34,234 +34,338 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
    return r;
}
namespace impl {
// the types below indicate the data type used as the payload of the buffer-load inline asm
// clang-format off
template<index_t N, typename T> struct buffer_load_trait;
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
#endif
// clang-format on
} // namespace impl
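The trait only selects which register type the asm operand binds to; the payload must still cover exactly the bytes of the destination thread buffer. A compile-time sketch of that invariant, assuming the ck_tile types above are in scope (with the workaround off, the generic fp32x4_t payload satisfies the same check):

// Sketch: whichever specialization is active (bf16x8_t under CK_TILE_BUFFER_LOAD_RAW_BF16_WA,
// otherwise the generic fp32x4_t), the payload aliases the whole 16-byte thread_buffer.
static_assert(sizeof(impl::buffer_load_trait<16, thread_buffer<bf16_t, 8>>::payload_t) ==
                  sizeof(thread_buffer<bf16_t, 8>),
              "payload type must cover the full thread_buffer");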
// TODO: glc/slc/...
template <index_t bytes, bool pre_nop = false>
struct buffer_load;

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: the strict aliasing rule seems to fail when reinterpret_cast-ing between vector types
// (exp_vector_type(xxx))
template <bool pre_nop>
struct buffer_load<16, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
        else
            asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load<8, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
        else
            asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load<4, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
        else
            asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load<2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
        using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
        else
            asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load<1, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
        else
            asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset)
                         : "memory");
    }
};

template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;

template <bool pre_nop>
struct buffer_load_if<16, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 16);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
        static_assert(sizeof(mbuf_t) == sizeof(T));
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
        else
            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load_if<8, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 8);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
        else
            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load_if<4, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
        else
            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load_if<2, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
        else
            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
    }
};

template <bool pre_nop>
struct buffer_load_if<1, pre_nop>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 0,
                                   bool_constant<pre_nop> = {})
    {
        static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
        else
            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                         "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
                         "s_mov_b64 exec %5"
                         : "+v"(reinterpret_cast<mbuf_t&>(value))
                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
                         : "memory");
    }
};

#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
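Since pre_nop is a template parameter rather than a runtime flag, the extra s_nop is only emitted at call sites that explicitly request it. A minimal sketch of such a call site (the wrapper name is hypothetical; the operator() signature is the one defined above):

// Hypothetical helper: raw 16-byte buffer load, optionally prefixed with the s_nop that
// works around the missing s_waitcnt scheduling seen with the ROCm 6.2 compiler.
template <bool pre_nop, typename T>
CK_TILE_DEVICE void load_dwordx4_raw(T& dst, int32x4_t buf_res, index_t v_offset)
{
    static_assert(sizeof(T) == 16);
    buffer_load<16, pre_nop>{}(dst,
                               buf_res,
                               v_offset,
                               0, // s_offset (ignored by the new encoding)
                               0, // i_offset
                               0, // flag (unused by the unconditional load)
                               bool_constant<pre_nop>{});
}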
@@ -275,17 +379,16 @@ struct buffer_store<16>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = fp32x4_t;
        asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                     : "memory");
    }
};
@@ -296,17 +399,16 @@ struct buffer_store<8>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = fp32x2_t;
        asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                     : "memory");
    }
};
@@ -317,17 +419,16 @@ struct buffer_store<4>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = float;
        asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                     : "memory");
    }
};
@@ -338,17 +439,16 @@ struct buffer_store<2>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 2);
        using mbuf_t = short;
        asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                     : "memory");
    }
};
@@ -359,17 +459,16 @@ struct buffer_store<1>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t /*flag*/ = 1)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = float;
        asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
                     :
                     : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                     : "memory");
    }
};
@@ -383,21 +482,20 @@ struct buffer_store_if<16>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 16);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = fp32x4_t;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)),
                       "v"(v_offset),
                       "s"(res),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
@@ -412,7 +510,7 @@ struct buffer_store_if<8>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
@@ -420,14 +518,13 @@ struct buffer_store_if<8>
        auto save_exec = __builtin_amdgcn_read_exec();
        // TODO: ugly. rocm-6.0/6.1 seems to need bit_cast to the same base type to avoid scratch
        using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)),
                       "v"(v_offset),
                       "s"(res),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
@@ -442,21 +539,20 @@ struct buffer_store_if<4>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 4);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = float;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)),
                       "v"(v_offset),
                       "s"(res),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
@@ -471,21 +567,20 @@ struct buffer_store_if<2>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 2);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = short;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)),
                       "v"(v_offset),
                       "s"(res),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
@@ -500,21 +595,20 @@ struct buffer_store_if<1>
    CK_TILE_DEVICE void operator()(const T& value,
                                   int32x4_t res /*buffer resource*/,
                                   index_t v_offset,
                                   index_t /*s_offset*/,
                                   index_t i_offset /*max 0xFFF*/,
                                   index_t flag = 1)
    {
        static_assert(sizeof(T) == 4);
        auto save_exec = __builtin_amdgcn_read_exec();
        using mbuf_t = float;
        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
                     "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
                     "s_mov_b64 exec %5"
                     :
                     : "v"(bit_cast<mbuf_t>(value)),
                       "v"(v_offset),
                       "s"(res),
                       "n"(i_offset),
                       "v"(flag),
                       "s"(save_exec)
@@ -538,8 +632,9 @@ namespace impl{
template<index_t N>
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
{
    constexpr auto kSize = remove_cvref_t<decltype(b)>::size();
    static_for<0, kSize, 1>{}([&](auto i){
        asm volatile(" " : : "v"(b.get(number<i>{})) : "memory");
    });
}
#if 1
@@ -769,6 +864,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");

// buffer store ui16
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
                                  int32x4_t rsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
                                    int32x4_t rsrc,
                                    index_t voffset,
                                    index_t soffset,
                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");

CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
                                   int32x4_t rsrc,
@@ -859,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int soffset, // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
                                              int32x4_t rsrc,
                                              index_t voffset,
                                              index_t /*soffset*/,
                                              index_t ioffset /*max 0xFFF*/,
                                              index_t /*flag*/ = 0,
                                              bool_constant<pre_nop> = {})
{
    if constexpr(pre_nop)
        asm volatile("s_nop 4\n"
                     "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
                     : "=r"(smem) /*dummy dependency for smem*/
                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
                     : "memory");
    else
        asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
                     : "=r"(smem) /*dummy dependency for smem*/
                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
                     : "memory");
}

CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
@@ -1181,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                             int32x4_t src_wave_buffer_resource,
                                             index_t src_thread_addr_offset,
                                             index_t src_wave_addr_offset,
                                             index_t flag = 0,
                                             bool_constant<pre_nop> = {})
{
    constexpr index_t bytes = sizeof(T) * N;
    static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
@@ -1195,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
    using type = thread_buffer<T, N>;

    if constexpr(oob_conditional_check)
    {
        buffer_load_if<sizeof(type), pre_nop>{}(dst,
                                                src_wave_buffer_resource,
                                                src_thread_addr_offset,
                                                src_wave_addr_offset,
                                                0,
                                                flag,
                                                bool_constant<pre_nop>{});
    }
    else
    {
        buffer_load<sizeof(type), pre_nop>{}(dst,
                                             src_wave_buffer_resource,
                                             src_thread_addr_offset,
                                             src_wave_addr_offset,
                                             0,
                                             flag,
                                             bool_constant<pre_nop>{});
    }
}
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                               int32x4_t src_wave_buffer_resource,
                                               index_t src_thread_addr_offset,
                                               index_t src_wave_addr_offset,
                                               index_t src_immediate_addr_offset = 0,
                                               bool_constant<pre_nop> = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

    async_buffer_load_dword_v(smem,
                              src_wave_buffer_resource,
                              src_thread_addr_offset,
                              src_wave_addr_offset,
                              src_immediate_addr_offset,
                              0,
                              bool_constant<pre_nop>{});
}

template <index_t N,
@@ -1339,7 +1481,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                         (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                         (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                         (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                         (std::is_same<T, uint16_t>::value &&
                          (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                         (std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                     "wrong! not implemented");

    if constexpr(std::is_same<T, float>::value) // fp32
@@ -1478,6 +1623,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
                                                    static_cast<index_t>(coherence));
        }
    }
    else if constexpr(std::is_same<T, uint16_t>::value)
    {
        if constexpr(N == 1)
        {
            llvm_amdgcn_raw_buffer_store_ui16(bit_cast<uint16_t>(src_thread_data),
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
                                              static_cast<index_t>(coherence));
        }
        else if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast<uint16x2_t>(src_thread_data),
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 4)
        {
            llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast<uint16x4_t>(src_thread_data),
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset,
                                                static_cast<index_t>(coherence));
        }
        else if constexpr(N == 8)
        {
            llvm_amdgcn_raw_buffer_store_ui16x4(
                src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
                dst_wave_buffer_resource,
                dst_thread_addr_offset,
                dst_wave_addr_offset,
                static_cast<index_t>(coherence));
            llvm_amdgcn_raw_buffer_store_ui16x4(
                src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
                dst_wave_buffer_resource,
                dst_thread_addr_offset,
                dst_wave_addr_offset + 4 * sizeof(uint16_t),
                static_cast<index_t>(coherence));
        }
    }
    else
    {
        using r_t = thread_buffer<int8_t, sizeof(T) * N>;
@@ -1595,7 +1783,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
    {
        if constexpr(N == 2)
        {
            llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
                                                     dst_wave_buffer_resource,
                                                     dst_thread_addr_offset,
                                                     dst_wave_addr_offset,
@@ -1821,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                        const T* p_src_wave,
                                        index_t src_thread_element_offset,
                                        index_t src_element_space_size,
                                        index_t is_valid_element = 0,
                                        bool_constant<pre_nop> = {})
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
        dst,
        src_wave_buffer_resource,
        src_thread_addr_offset,
        0,
        is_valid_element,
        bool_constant<pre_nop>{});
}

// This version supports a buffer resource as the input argument
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                        const int32x4_t src_wave_buffer_resource,
                                        index_t src_thread_element_offset,
                                        index_t is_valid_element = 0,
                                        bool_constant<pre_nop> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
        dst,
        src_wave_buffer_resource,
        src_thread_addr_offset,
        0,
        is_valid_element,
        bool_constant<pre_nop>{});
}
// unfortunately async copy cannot make sure invalid data is zero inside LDS
@@ -1843,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const T* p_src_wave,
                                                       index_t src_thread_element_offset,
                                                       index_t src_element_space_size,
                                                       bool_constant<pre_nop> = {})
{
    const int32x4_t src_wave_buffer_resource =
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
@@ -1855,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    amd_async_buffer_load_impl<T, N, coherence>(
        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}

// This version supports a buffer resource as the input argument
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const int32x4_t src_wave_buffer_resource,
                                                       index_t src_thread_element_offset,
                                                       bool_constant<pre_nop> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

    amd_async_buffer_load_impl<T, N, coherence>(
        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// buffer_store requires:
...
@@ -79,14 +79,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
                 " ::);
}

CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
    asm volatile("s_nop %0" : : "n"(cnt) :);
#else
    __builtin_amdgcn_sched_barrier(cnt);
#endif
}
...
@@ -18,6 +18,7 @@
#define __gfx11__
#endif

#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
@@ -144,6 +145,15 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133)
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
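Because the macro is wrapped in #ifndef, the HIP-version detection above only supplies a default and can be overridden per build. A small sketch of forcing it off explicitly (the header path shown is an assumption; the compiler flag works regardless):

// Option 1: define the macro before any ck_tile header is included
//           (path is illustrative; use whatever umbrella header the project includes).
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#include "ck_tile/core.hpp"

// Option 2: pass it on the command line, e.g.
//   hipcc -DCK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE=0 ...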
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
@@ -167,7 +177,15 @@
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif

#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
#endif

// TODO: better to solve this inside the compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
@@ -68,6 +68,8 @@ struct buffer_view<address_space_enum::generic,
    {
    }
CK_TILE_HOST_DEVICE void init_raw() {}
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::generic;
@@ -223,25 +225,36 @@ struct buffer_view<address_space_enum::global,
    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    int32x4_t cached_buf_res_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          cached_buf_res_{0},
          invalid_element_value_{invalid_element_value}
    {
    }

    // this is intentionally not constexpr (it calls an intrinsic internally).
    // Must be called for buffers that need the *_raw load/store path.
    CK_TILE_HOST_DEVICE void init_raw()
    {
        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::global;
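Any code path that goes through the *_raw loads/stores now depends on the cached buffer resource, so a view has to be primed once with init_raw(); the generic/LDS/VGPR overloads are empty, which keeps the call safe on every address space. A hedged sketch of how calling code might guarantee that (the helper name is hypothetical):

// Hypothetical helper: prime a view for the raw access path. For global views this caches
// the buffer resource descriptor into cached_buf_res_; elsewhere init_raw() is a no-op.
template <typename View>
CK_TILE_HOST_DEVICE void prepare_for_raw_access(View& view)
{
    view.init_raw();
}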
@@ -332,12 +345,15 @@ struct buffer_view<address_space_enum::global,
    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
                                          index_t i,
                                          bool is_valid_element,
                                          bool_constant<pre_nop> = {}) const
    {
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -348,18 +364,21 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
                                                index_t i,
                                                bool /*is_valid_element*/,
                                                bool_constant<pre_nop> = {}) const
    {
        // X is vector of T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -370,8 +389,8 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
    }
    // i is offset of T, not X. i should be aligned to X
@@ -626,6 +645,8 @@ struct buffer_view<address_space_enum::lds,
    {
    }
CK_TILE_HOST_DEVICE void init_raw() {}
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::lds;
@@ -908,6 +929,8 @@ struct buffer_view<address_space_enum::vgpr,
    {
    }
CK_TILE_HOST_DEVICE void init_raw() {}
    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::vgpr;
...
@@ -36,30 +36,37 @@ template <typename T,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_with_static_distribution<BottomTensorView_,
                                                                              WindowLengths_,
                                                                              TileDistribution_,
                                                                              NumCoord>& tile_window,
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true,
          bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window,
                    bool_constant<oob_conditional_check> = {},
                    bool_constant<pre_nop> = {})
{
    return tile_window.async_load_raw(
        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
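A hedged call-site sketch of the extended entry points, assuming two already-constructed tile windows (the surrounding helper and its names are illustrative): the last bool_constant requests the s_nop prefix, and the fence waits before the LDS data is consumed.

// Hypothetical pipeline step: asynchronous global->LDS copy with the pre_nop workaround on.
template <typename LdsWindow, typename GmemWindow>
CK_TILE_DEVICE void async_copy_tile(LdsWindow& lds_window, const GmemWindow& gmem_window)
{
    async_load_tile_raw(lds_window,
                        gmem_window,
                        bool_constant<true>{},  // oob_conditional_check
                        bool_constant<true>{}); // pre_nop: emit the leading s_nop
    async_load_fence(); // wait for the outstanding async buffer loads
}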
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
...
@@ -35,6 +35,8 @@ struct null_tile_window
    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    CK_TILE_DEVICE void init_raw() {}

    WindowLengths window_lengths_;
};
...
@@ -33,6 +33,8 @@ struct tensor_view
    {
    }

    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }

    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
@@ -82,30 +84,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X, template <typename X,
bool oob_conditional_check = true, bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE void CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
get_vectorized_elements_raw(remove_cvref_t<X>& dst, const TensorCoord& coord,
const TensorCoord& coord, bool_constant<oob_conditional_check> = {},
bool_constant<oob_conditional_check> = {}) const bool_constant<pre_nop> = {}) const
{ {
return buf_.template get_raw<X, oob_conditional_check>( return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, dst,
coord.get_offset(), coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
} }
template <typename X, template <typename X,
bool pre_nop = false,
typename std::enable_if< typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type, std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>, typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false> bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem, CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
const TensorCoord& coord) const remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{ {
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/); return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
} }
// X is vector of DataType. // X is vector of DataType.
......
...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) ...@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use a per-dword value to set a tensor, in case the compiler does not handle // TODO: prefer to use a per-dword value to set a tensor, in case the compiler does not handle
// sub-dword tensors well... // sub-dword tensors well...
template <typename DstrTensors, index_t v> template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>) CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{ {
constexpr index_t tensor_bytes = using elem_type = typename DstrTensors::DataType;
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); constexpr index_t elem_size = sizeof(elem_type);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
// # bytes per write = 4
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
{ {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
auto& buffer = dstr_tensor.get_thread_buffer();
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
if constexpr(elem_size == 1)
{
// # elements per write = 4
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
buffer[i_write * 4 + 0] = values.x;
buffer[i_write * 4 + 1] = values.y;
buffer[i_write * 4 + 2] = values.z;
buffer[i_write * 4 + 3] = values.w;
}
else if constexpr(elem_size == 2)
{
// # elements per write = 2
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
buffer[i_write * 2 + 0] = values.x;
buffer[i_write * 2 + 1] = values.y;
}
else if constexpr(elem_size == 4)
{
// # elements per write = 1
constexpr elem_type value = 0;
buffer[i_write] = value;
}
else
{
static_assert(false, "type not supported");
}
});
#else
using dvec_t = array<index_t, tensor_bytes / 4>; using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer()); auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++) for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v; tensor.get(i) = v;
#endif
} }
else else
{ {
tile_elementwise_inout( tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); }, dstr_tensor);
dstr_tensor);
} }
} }
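A short call-site sketch of the reworked set_tile (hedged; the accumulator tensor acc is hypothetical, while set_tile, number<> and bool_constant<> are the entry points shown above):

    set_tile(acc, number<0>{});                        // default: per-dword zero fill when the size allows it
    set_tile(acc, number<0>{}, bool_constant<true>{}); // skip_subdword_opt: force the element-wise fallback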
...@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) ...@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
namespace impl { namespace impl {
// TODO: this is ugly // TODO: this is ugly
template <typename OutDataType, typename InTensor> template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{ {
#if defined(__gfx94__) #if defined(__gfx94__)
// This API is designed to use the _pk_ series of functions // This API is designed to use the _pk_ series of functions
...@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) ...@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
#endif #endif
} }
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
// This API is designed to use the _pk_ series of functions
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
static_assert(thread_buffer_size % 2 == 0);
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
// TODO: this is an RTZ convert, need to be very careful
for(index_t i = 0; i < thread_buffer_size_pk; i++)
{
auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
}
return out_dstr_tensor;
#else
// fallback
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST #if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst (or both) data type is under 1 dword // this function assumes either the src or dst (or both) data type is under 1 dword
// we pack sub-dword values into 1 dword to avoid the compiler's default sub-dword behavior (which is buggy) // we pack sub-dword values into 1 dword to avoid the compiler's default sub-dword behavior (which is buggy)
...@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) ...@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
float> && float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0)) (SrcTensor::get_thread_buffer_size() % 4 == 0))
{ {
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor); return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
} }
#if CK_TILE_USE_PK_FP16_TILE_CAST
else if constexpr(std::is_same_v<DstType, fp16_t> &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 2 == 0))
{
return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
}
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST #if CK_TILE_USE_SUBDWORD_TILE_CAST
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{ {
......
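As a hedged usage sketch of the new packed fp32-to-fp16 branch (the source tensor p_fp32 is hypothetical; cast_tile, fp16_t and the CK_TILE_USE_PK_FP16_TILE_CAST guard come from the hunk above). The packed builtin rounds toward zero rather than to nearest, which is what the "rtz cvt" TODO warns about:

    // assumes CK_TILE_USE_PK_FP16_TILE_CAST is enabled and the thread buffer size is even,
    // so pairs of fp32 values go through __builtin_amdgcn_cvt_pkrtz on gfx908/gfx90a/gfx94
    auto p_fp16 = cast_tile<fp16_t>(p_fp32);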
...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution ...@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
return dst_tensor; return dst_tensor;
} }
template <typename DstTile, bool oob_conditional_check = true> template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{ {
using Traits = load_store_traits; using Traits = load_store_traits;
...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution ...@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// data index [y0, y1, ...] // data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
...@@ -384,8 +391,12 @@ struct tile_window_with_static_distribution ...@@ -384,8 +391,12 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>( get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(), dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord, bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{}); bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile(
""); // this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{ {
...@@ -399,12 +410,17 @@ struct tile_window_with_static_distribution ...@@ -399,12 +410,17 @@ struct tile_window_with_static_distribution
} }
}); });
}); });
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
} }
// TODO: currently async load only implemented in inline asm // TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true> template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{ {
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>; using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView; // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
...@@ -449,11 +465,17 @@ struct tile_window_with_static_distribution ...@@ -449,11 +465,17 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{}; constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// read from bottom tensor // read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>( get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord); smem, bottom_tensor_thread_coord, pre_nop_);
// move thread coordinate // move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
...@@ -608,6 +630,67 @@ struct tile_window_with_static_distribution ...@@ -608,6 +630,67 @@ struct tile_window_with_static_distribution
}); });
} }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this uses more registers for FA, but fewer registers for GEMM
// needs investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this uses fewer registers for FA, but more registers for GEMM
// needs investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view // this is the bottom tensor view
// [x0', x1', ...] ==> [offset] // [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_; BottomTensorView bottom_tensor_view_;
......
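The new set_window_origin member exists so the FMHA pipeline below can reuse an existing K DRAM window instead of rebuilding it each outer-loop iteration; a hedged sketch of the intended pattern (window names are taken from the pipeline hunk further down):

    // before: k_dram_window = make_tile_window(...); rebuilt every iteration
    // after: keep the window and only move its origin; pre_computed_coords_ are refreshed inside
    move_tile_window(k_dram_block_window, {kN0, 0});
    k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());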
...@@ -78,6 +78,12 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -78,6 +78,12 @@ struct BlockFmhaPipelineQRKSVSAsync
return Problem::kBlockPerCu; return Problem::kBlockPerCu;
else else
{ {
// minimize occupancy
if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS)
{
return 1;
}
if constexpr(kK0BlockLength <= 32) if constexpr(kK0BlockLength <= 32)
{ {
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
...@@ -212,11 +218,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -212,11 +218,14 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
q_dram_window.init_raw();
// TODO: we use async Copy for K, which is inline asm // TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well // a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){}; auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window); load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -285,6 +294,16 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -285,6 +294,16 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window.get_window_origin(), k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
k_dram_window.init_raw();
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = [&]() {
if constexpr(kPadSeqLenK && (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
(BiasEnum != BlockAttentionBiasEnum::NO_BIAS)))
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window( auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_bottom_tensor_view(),
...@@ -299,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -299,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>()); Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile // prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -322,7 +341,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -322,7 +341,9 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}), async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window); k_dram_window,
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
...@@ -609,16 +630,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -609,16 +630,13 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
// move K tile windows // move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0}); move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window = k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
if constexpr(k1_loops >= 2 && if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{})) LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
} }
// tail // tail
......
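Putting the pipeline-side hunks together, the updated K prefetch sequence looks roughly like this (a condensed, hedged restatement of the changed lines above, not new behavior; the redundant bias condition is simplified):

    q_dram_window.init_raw();
    k_dram_window.init_raw();

    constexpr auto k_oob_ck = bool_constant<true>{}; // always predicate K loads
    constexpr auto k_pre_np =
        bool_constant<kPadSeqLenK && BiasEnum != BlockAttentionBiasEnum::NO_BIAS>{};

    async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
    move_tile_window(k_dram_window, {0, kK0});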