gaoqiong / composable_kernel_ROCM / Commits / 66593407

Unverified commit 66593407, authored Aug 05, 2024 by Po Yen Chen, committed by GitHub Aug 04, 2024
[CK_TILE] Pick bugfixes for ROCm 6.2 compiler issues (#1430)
parent 00626ca8

Showing 11 changed files with 747 additions and 266 deletions (+747, -266)
example/ck_tile/01_fmha/generate.py                                         +17   -6
include/ck_tile/core/arch/amd_buffer_addressing.hpp                         +435  -199
include/ck_tile/core/arch/arch.hpp                                          +3    -5
include/ck_tile/core/config.hpp                                             +18   -0
include/ck_tile/core/tensor/buffer_view.hpp                                 +34   -11
include/ck_tile/core/tensor/load_tile.hpp                                   +13   -6
include/ck_tile/core/tensor/null_tile_window.hpp                            +2    -0
include/ck_tile/core/tensor/tensor_view.hpp                                 +15   -9
include/ck_tile/core/tensor/tile_elementwise.hpp                            +89   -10
include/ck_tile/core/tensor/tile_window.hpp                                 +94   -11
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp    +27   -9
example/ck_tile/01_fmha/generate.py  (view file @ 66593407)

@@ -355,6 +355,9 @@ class FmhaFwdApiPool:
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
             per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+        if not per_dtypes:
+            # empty string we add some ignore to suppress warning in api
+            per_dtypes += ' (void)t ; (void)s ; (void)a;'
         return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)

 @dataclass

@@ -489,19 +492,27 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
         pipelines = []
         if dtype in ['fp16', 'bf16']:
             for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
-                if hdim == 256:
+                # if hdim=32, fallback to 'qr' pipeline to workaround rocm 6.2 compiler problem (missing s_waitcnt)
+                if hdim == 256 or hdim == 32:
                     # if True:
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                     pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                     pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                 else:
+                    if bias == "bias":
+                        # TODO: rocm 6.2 compiler problem if using qr_async for bias case
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
+                    else:
                         pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
                         pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                         pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
                         pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
-                if receipt == 1:
+                if receipt == 1 and bias != "bias":
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))  # TODO: cover arbitraty hdim
                     pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))  # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
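For context on the added `(void)` casts: when the kernel filter leaves no dtype cases, per_dtypes is empty and the generated dispatch body would otherwise trigger unused-parameter warnings. The sketch below is hypothetical (the real signature comes from the FMHA_FWD_API template string in this file); only the injected ignore line is taken from the commit.

    // Hypothetical shape of the generated dispatch when per_dtypes is empty.
    float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args s, const ck_tile::stream_config& a)
    {
        float r = -1;
        (void)t ; (void)s ; (void)a;  // the ignores injected by the generator above
        return r;
    }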
include/ck_tile/core/arch/amd_buffer_addressing.hpp  (view file @ 66593407)

@@ -34,233 +34,337 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
     return r;
 }

+namespace impl {
+// below type indicate the data type used for buffer load inline asm
+// clang-format off
+template <index_t N, typename T> struct buffer_load_trait;
+
+template <typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
+template <typename T> struct buffer_load_trait<8,  T> { using payload_t = fp32x2_t; };
+template <typename T> struct buffer_load_trait<4,  T> { using payload_t = float; };
+template <typename T> struct buffer_load_trait<2,  T> { using payload_t = float; };
+template <typename T> struct buffer_load_trait<1,  T> { using payload_t = float; };
+
+#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
+template <> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
+template <> struct buffer_load_trait<8,  thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
+template <> struct buffer_load_trait<4,  thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
+#endif
+// clang-format on
+} // namespace impl

 // TODO: glc/slc/...
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
 struct buffer_load;
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
 // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
 // (exp_vector_type(xxx))
-template <>
-struct buffer_load<16>
+template <bool pre_nop>
+struct buffer_load<16, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t /*flag*/ = 0)
+                                   index_t /*flag*/ = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 16);
-        using mbuf_t = fp32x4_t;
-        asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
-                     : "memory");
+        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
+        else
+            asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset)
+                         : "memory");
     }
 };
 (The buffer_load<8>, buffer_load<4>, buffer_load<2> and buffer_load<1> specializations in this
 hunk receive the same treatment: each becomes buffer_load<bytes, pre_nop>, takes the extra
 bool_constant<pre_nop> argument, ignores the s_offset operand (the asm now encodes a literal 0
 soffset), resolves mbuf_t through impl::buffer_load_trait<bytes, T>, and prepends "s_nop 4"
 before buffer_load_dwordx2 / buffer_load_dword / buffer_load_ushort / buffer_load_ubyte when
 pre_nop is set. The 2- and 1-byte cases keep their "subdword is buggy, use dword buf and convert
 manually" note and static_assert(sizeof(T) == 4).)
-template <index_t bytes>
+template <index_t bytes, bool pre_nop = false>
 struct buffer_load_if;

-template <>
-struct buffer_load_if<16>
+template <bool pre_nop>
+struct buffer_load_if<16, pre_nop>
 {
     template <typename T>
     CK_TILE_DEVICE void operator()(T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
-                                   index_t flag = 0)
+                                   index_t flag = 0,
+                                   bool_constant<pre_nop> = {})
     {
         static_assert(sizeof(T) == 16);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = fp32x4_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : "+v"(reinterpret_cast<mbuf_t&>(value))
-                     : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
-                     : "memory");
+        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
+        static_assert(sizeof(mbuf_t) == sizeof(T));
+        if constexpr(pre_nop)
+            asm volatile("s_nop 4\n"
+                         "v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
+        else
+            asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                         "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                         "s_mov_b64 exec %5"
+                         : "+v"(reinterpret_cast<mbuf_t&>(value))
+                         : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
+                         : "memory");
     }
 };

 (buffer_load_if<8>, <4>, <2> and <1> are changed in the same way in this hunk, using
 buffer_load_dwordx2 / buffer_load_dword / buffer_load_ushort / buffer_load_ubyte respectively.)
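A trimmed caller-side sketch of the new interface (the wrapper below is hypothetical, not part of the commit): pre_nop is a compile-time flag, so the extra "s_nop 4" is only emitted in instantiations that request it, and the s_offset slot is now just a placeholder because the asm encodes a literal 0 soffset.

    // Hypothetical wrapper showing how the new pre_nop parameter is threaded through.
    template <bool pre_nop = false>
    CK_TILE_DEVICE void load_16_bytes_raw(thread_buffer<bf16_t, 8>& dst,
                                          int32x4_t buf_res,
                                          index_t   v_offset)
    {
        // args: value, res, v_offset, s_offset (ignored), i_offset, flag, pre_nop tag
        buffer_load<16, pre_nop>{}(dst, buf_res, v_offset, 0, 0, 0, bool_constant<pre_nop>{});
    }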
@@ -275,16 +379,15 @@ struct buffer_store<16>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t /*flag*/ = 1)
     {
         static_assert(sizeof(T) == 16);
         using mbuf_t = fp32x4_t;
-        asm volatile("buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
-                     : : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
+        asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
+                     : : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
                      : "memory");
     }
 };

 (Hunks @@ -296,16 +399,15 @@, @@ -317,16 +419,15 @@, @@ -338,16 +439,15 @@ and
 @@ -359,16 +459,15 @@ apply the identical change to buffer_store<8>, <4>, <2> and <1>:
 buffer_store_dwordx2 / buffer_store_dword / buffer_store_short / buffer_store_byte now encode
 a literal 0 soffset and the s_offset argument becomes unused.)
@@ -383,21 +482,20 @@ struct buffer_store_if<16>
     CK_TILE_DEVICE void operator()(const T& value,
                                    int32x4_t res /*buffer resource*/,
                                    index_t v_offset,
-                                   index_t s_offset,
+                                   index_t /*s_offset*/,
                                    index_t i_offset /*max 0xFFF*/,
                                    index_t flag = 1)
     {
         static_assert(sizeof(T) == 16);
         auto save_exec = __builtin_amdgcn_read_exec();
         using mbuf_t = fp32x4_t;
-        asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
-                     "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
-                     "s_mov_b64 exec %6"
-                     : : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec)
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     : : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(save_exec)

 (Hunks @@ -412,7 +510,7 @@ / @@ -420,14 +518,13 @@, @@ -442,21 +539,20 @@,
 @@ -471,21 +567,20 @@ and @@ -500,21 +595,20 @@ make the same change to buffer_store_if<8>,
 <4>, <2> and <1> (buffer_store_dwordx2 / buffer_store_dword / buffer_store_short /
 buffer_store_byte). The <8> case keeps its note "TODO: ugly. rocm-6.0/6.1 seems neet bit_cast
 to same base type to avoid scratch" and its ext_vector_t<typename T::value_type, T::size()>
 payload type.)
@@ -538,8 +632,9 @@ namespace impl{
 template <index_t N>
 CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
 {
-    static_for<0, b.size(), 1>{}([&](auto i) {
-        asm volatile(" " : : "v"(b.get(i)) : "memory");
+    constexpr auto kSize = remove_cvref_t<decltype(b)>::size();
+    static_for<0, kSize, 1>{}([&](auto i) {
+        asm volatile(" " : : "v"(b.get(number<i>{})) : "memory");
     });
 }
 #if 1
@@ -769,6 +864,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");

+// buffer store ui16
+CK_TILE_DEVICE_EXTERN void
+llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata,
+                                  int32x4_t rsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
+
+CK_TILE_DEVICE_EXTERN void
+llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
+
+CK_TILE_DEVICE_EXTERN void
+llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
+
 CK_TILE_DEVICE_EXTERN void
 llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
                                    int32x4_t rsrc,
@@ -859,16 +976,25 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                        int soffset,  // dst_wave_addr_offset
                                        int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

-CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
+template <bool pre_nop = false>
+CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
                                             int32x4_t rsrc,
                                             index_t voffset,
-                                            index_t soffset,
+                                            index_t /*soffset*/,
                                             index_t ioffset /*max 0xFFF*/,
-                                            index_t /*flag*/ = 0)
+                                            index_t /*flag*/ = 0,
+                                            bool_constant<pre_nop> = {})
 {
-    asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
-                 : "=r"(smem) /*dummy dependency for smem*/
-                 : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
-                 : "memory");
+    if constexpr(pre_nop)
+        asm volatile("s_nop 4\n"
+                     "buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
+    else
+        asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
+                     : "=r"(smem) /*dummy dependency for smem*/
+                     : "v"(voffset), "s"(rsrc), "n"(ioffset)
+                     : "memory");
 }
@@ -1181,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
 template <typename T,
           index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                              int32x4_t src_wave_buffer_resource,
                                              index_t src_thread_addr_offset,
                                              index_t src_wave_addr_offset,
-                                             index_t flag = 0)
+                                             index_t flag = 0,
+                                             bool_constant<pre_nop> = {})
 {
     constexpr index_t bytes = sizeof(T) * N;
     static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
@@ -1195,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
     using type = thread_buffer<T, N>;
     if constexpr(oob_conditional_check)
     {
-        buffer_load_if<sizeof(type)>{}(
-            dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
+        buffer_load_if<sizeof(type), pre_nop>{}(dst,
+                                                src_wave_buffer_resource,
+                                                src_thread_addr_offset,
+                                                src_wave_addr_offset,
+                                                0,
+                                                flag,
+                                                bool_constant<pre_nop>{});
     }
     else
     {
-        buffer_load<sizeof(type)>{}(
-            dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
+        buffer_load<sizeof(type), pre_nop>{}(dst,
+                                             src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset,
+                                             0,
+                                             flag,
+                                             bool_constant<pre_nop>{});
     }
 }

 template <typename T,
           index_t N,
-          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop = false>
 CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
                                                int32x4_t src_wave_buffer_resource,
                                                index_t src_thread_addr_offset,
                                                index_t src_wave_addr_offset,
-                                               index_t src_immediate_addr_offset = 0)
+                                               index_t src_immediate_addr_offset = 0,
+                                               bool_constant<pre_nop> = {})
 {
     static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

-    async_buffer_load_dword(smem,
-                            src_wave_buffer_resource,
-                            src_thread_addr_offset,
-                            src_wave_addr_offset,
-                            src_immediate_addr_offset);
+    async_buffer_load_dword_v(smem,
+                              src_wave_buffer_resource,
+                              src_thread_addr_offset,
+                              src_wave_addr_offset,
+                              src_immediate_addr_offset,
+                              0,
+                              bool_constant<pre_nop>{});
 }

 template <index_t N,
@@ -1339,7 +1481,10 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
                       (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, fp8_t>::value &&
                       (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, bf8_t>::value &&
                       (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-                     (std::is_same<T, int8_t>::value &&
-                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+                     (std::is_same<T, int8_t>::value &&
+                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+                     (std::is_same<T, uint16_t>::value &&
+                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+                     (std::is_same<T, uint8_t>::value &&
+                      (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
                      "wrong! not implemented");

     if constexpr(std::is_same<T, float>::value) // fp32
@@ -1478,6 +1623,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer<T, N> src_thread_d
                                                  static_cast<index_t>(coherence));
         }
     }
+    else if constexpr(std::is_same<T, uint16_t>::value)
+    {
+        if constexpr(N == 1)
+        {
+            llvm_amdgcn_raw_buffer_store_ui16(bit_cast<uint16_t>(src_thread_data),
+                                              dst_wave_buffer_resource,
+                                              dst_thread_addr_offset,
+                                              dst_wave_addr_offset,
+                                              static_cast<index_t>(coherence));
+        }
+        else if constexpr(N == 2)
+        {
+            llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast<uint16x2_t>(src_thread_data),
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                static_cast<index_t>(coherence));
+        }
+        else if constexpr(N == 4)
+        {
+            llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast<uint16x4_t>(src_thread_data),
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                static_cast<index_t>(coherence));
+        }
+        else if constexpr(N == 8)
+        {
+            llvm_amdgcn_raw_buffer_store_ui16x4(src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset,
+                                                static_cast<index_t>(coherence));
+            llvm_amdgcn_raw_buffer_store_ui16x4(src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
+                                                dst_wave_buffer_resource,
+                                                dst_thread_addr_offset,
+                                                dst_wave_addr_offset + 4 * sizeof(uint16_t),
+                                                static_cast<index_t>(coherence));
+        }
+    }
     else
     {
         using r_t = thread_buffer<int8_t, sizeof(T) * N>;
@@ -1595,7 +1783,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
     {
         if constexpr(N == 2)
         {
-            llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16_t>(src_thread_data),
+            llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast<fp16x2_t>(src_thread_data),
                                                      dst_wave_buffer_resource,
                                                      dst_thread_addr_offset,
                                                      dst_wave_addr_offset,
@@ -1821,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
 template <typename T,
           index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const T* p_src_wave,
                                         index_t src_thread_element_offset,
                                         index_t src_element_space_size,
-                                        index_t is_valid_element = 0)
+                                        index_t is_valid_element = 0,
+                                        bool_constant<pre_nop> = {})
 {
     const int32x4_t src_wave_buffer_resource =
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

-    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
-        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
 }

+// This version support buffer resource as input arg
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
+                                        const int32x4_t src_wave_buffer_resource,
+                                        index_t src_thread_element_offset,
+                                        index_t is_valid_element = 0,
+                                        bool_constant<pre_nop> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
+        dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element,
+        bool_constant<pre_nop>{});
+}
+
 // unfortunately async copy can not make sure invalid data is zero inside LDS
@@ -1843,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
 // buffer_load OOB still working.
 template <typename T,
           index_t N,
-          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
-CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                    const T* p_src_wave,
                                                    index_t src_thread_element_offset,
-                                                   index_t src_element_space_size)
+                                                   index_t src_element_space_size,
+                                                   bool_constant<pre_nop> = {})
 {
     const int32x4_t src_wave_buffer_resource =
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));

@@ -1855,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

-    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
+    amd_async_buffer_load_impl<T, N, coherence>(
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
 }

+// This version support buffer resource as input arg
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool pre_nop = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
+                                                       const int32x4_t src_wave_buffer_resource,
+                                                       index_t src_thread_element_offset,
+                                                       bool_constant<pre_nop> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+
+    amd_async_buffer_load_impl<T, N, coherence>(
+        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+}
+
 // buffer_store requires:
include/ck_tile/core/arch/arch.hpp  (view file @ 66593407)

@@ -79,14 +79,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
     " ::);
 }

-CK_TILE_DEVICE void s_nop()
+CK_TILE_DEVICE void s_nop(index_t cnt = 0)
 {
 #if 1
-    asm volatile("\
-    s_nop 0 \n \
-    " ::);
+    asm volatile("s_nop %0" : : "n"(cnt) :);
 #else
-    __builtin_amdgcn_sched_barrier(0);
+    __builtin_amdgcn_sched_barrier(cnt);
 #endif
 }
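A minimal usage sketch (the call site below is hypothetical): because the count is passed through an "n" (immediate) inline-asm constraint, callers must supply a compile-time constant.

    // Hypothetical call site for the updated s_nop: the argument must be a
    // literal or otherwise constant-foldable, since it is encoded as an immediate.
    CK_TILE_DEVICE void wait_a_little_after_issue()
    {
        s_nop(4); // emits "s_nop 4"
    }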
include/ck_tile/core/config.hpp  (view file @ 66593407)

@@ -18,6 +18,7 @@
 #define __gfx11__
 #endif

+#include "hip/hip_version.h"
 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"

@@ -144,6 +145,15 @@
 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

+#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \
+    (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133)
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif

@@ -167,7 +177,15 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif

+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
+
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
 #endif
+
+#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
+#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
+#endif
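A small sketch (not from the commit) of how device code can branch on the version-gated macro; set_tile and clear_tile are the ck_tile helpers touched later in this same commit, and the function name is illustrative.

    // Minimal sketch: pick the per-dword clear path only when the workaround is active.
    template <typename DistributedTensor>
    CK_TILE_DEVICE void clear_accumulator(DistributedTensor& acc)
    {
    #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        set_tile(acc, number<0>{}); // per-dword zero fill, avoids scratch spills
    #else
        clear_tile(acc);            // plain element-wise clear
    #endif
    }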
include/ck_tile/core/tensor/buffer_view.hpp  (view file @ 66593407)

@@ -68,6 +68,8 @@ struct buffer_view<address_space_enum::generic,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::generic;

@@ -223,25 +225,36 @@ struct buffer_view<address_space_enum::global,
     T* p_data_ = nullptr;
     BufferSizeType buffer_size_;
+    int32x4_t cached_buf_res_;
     remove_cvref_t<T> invalid_element_value_ = T{0};

     CK_TILE_HOST_DEVICE constexpr buffer_view()
-        : p_data_{}, buffer_size_{}, invalid_element_value_{}
+        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
     {
     }

     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
     {
     }

     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0},
+          invalid_element_value_{invalid_element_value}
     {
     }

+    // this is non constexpr intentially (will call some intrinsic internally)
+    // Must call for buffers that need *_raw load/store
+    CK_TILE_HOST_DEVICE void init_raw()
+    {
+        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+    }
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::global;

@@ -332,12 +345,15 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop               = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
-    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
+    CK_TILE_DEVICE constexpr auto
+    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element, bool_constant<pre_nop> = {}) const
     {
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -348,18 +364,21 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            dst, p_data_, i, buffer_size_, is_valid_element);
+        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
     }

     // i is offset of T, not X. i should be aligned to X
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
-    CK_TILE_DEVICE constexpr auto async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
+    CK_TILE_DEVICE constexpr auto
+    async_get_raw(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/, bool_constant<pre_nop> = {}) const
     {
         // X is vector of T
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -370,8 +389,8 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-            smem, p_data_, i, buffer_size_);
+        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
+            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
     }

     // i is offset of T, not X. i should be aligned to X

@@ -626,6 +645,8 @@ struct buffer_view<address_space_enum::lds,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::lds;

@@ -908,6 +929,8 @@ struct buffer_view<address_space_enum::vgpr,
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::vgpr;
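A short usage sketch of the new caching contract (the factory call and names below are illustrative assumptions, not taken from the commit): a global-memory buffer_view used through the *_raw entry points must call init_raw() once so cached_buf_res_ holds a valid wave buffer resource before the first raw load.

    // Illustrative sketch: prime the cached buffer resource, then do a raw load.
    template <typename T>
    CK_TILE_DEVICE void example_raw_read(T* p_global, index_t n, thread_buffer<T, 4>& dst)
    {
        auto view = make_buffer_view<address_space_enum::global>(p_global, n); // assumed factory
        view.init_raw();                                  // caches make_wave_buffer_resource(...)
        view.template get_raw<thread_buffer<T, 4>>(dst, /*i=*/0, /*is_valid_element=*/true);
    }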
include/ck_tile/core/tensor/load_tile.hpp  (view file @ 66593407)

@@ -36,30 +36,37 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   const tile_window_with_static_distribution<BottomTensorView_,
                                                                               WindowLengths_,
                                                                               TileDistribution_,
                                                                               NumCoord>& tile_window,
-                                  bool_constant<oob_conditional_check> = {})
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,
           typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                         const tile_window_with_static_distribution<BottomTensorView_,
                                                                                     WindowLengths_,
                                                                                     TileDistribution_,
-                                                                                    NumCoord>& tile_window)
+                                                                                    NumCoord>& tile_window,
+                                        bool_constant<oob_conditional_check> = {},
+                                        bool_constant<pre_nop> = {})
 {
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
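Caller-side sketch of the two new trailing flags (the window and LDS names are placeholders); this mirrors how the FMHA pipeline further down passes k_oob_ck / k_pre_np.

    // oob_conditional_check stays on; pre_nop is only worth enabling for the
    // first raw load after a barrier (see the pipeline change below).
    async_load_tile_raw(k_lds_window, k_dram_window,
                        bool_constant<true>{},    // oob_conditional_check
                        bool_constant<false>{});  // pre_nop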
include/ck_tile/core/tensor/null_tile_window.hpp  (view file @ 66593407)

@@ -35,6 +35,8 @@ struct null_tile_window
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

+    CK_TILE_DEVICE void init_raw() {}
+
     WindowLengths window_lengths_;
 };
include/ck_tile/core/tensor/tensor_view.hpp  (view file @ 66593407)

@@ -33,6 +33,8 @@ struct tensor_view
     {
     }

+    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
+
     CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()

@@ -82,30 +84,34 @@ struct tensor_view
     // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop               = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
-    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
-                                                         const TensorCoord& coord,
-                                                         bool_constant<oob_conditional_check> = {}) const
+    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
+                                                         const TensorCoord& coord,
+                                                         bool_constant<oob_conditional_check> = {},
+                                                         bool_constant<pre_nop> = {}) const
     {
-        return buf_.template get_raw<X, oob_conditional_check>(
-            dst, coord.get_offset(),
-            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
+        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
+            dst,
+            coord.get_offset(),
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            bool_constant<pre_nop>{});
     }

     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
-    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
-                                                                     const TensorCoord& coord) const
+    CK_TILE_HOST_DEVICE constexpr void
+    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                      const TensorCoord& coord,
+                                      bool_constant<pre_nop> = {}) const
     {
-        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
+        return buf_.template async_get_raw<X>(
+            smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
     }

     // X is vector of DataType.
include/ck_tile/core/tensor/tile_elementwise.hpp  (view file @ 66593407)

@@ -76,22 +76,62 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
 // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
 // sub-dword tensor...
-template <typename DstrTensors, index_t v>
-CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
+template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
+CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
 {
-    constexpr index_t tensor_bytes =
-        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
-    if constexpr(v == 0 && tensor_bytes % 4 == 0)
+    using elem_type                = typename DstrTensors::DataType;
+    constexpr index_t elem_size    = sizeof(elem_type);
+    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
+
+    // # bytes per write = 4
+    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
     {
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        auto& buffer = dstr_tensor.get_thread_buffer();
+
+        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
+            if constexpr(elem_size == 1)
+            {
+                // # elements per write = 4
+                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
+
+                buffer[i_write * 4 + 0] = values.x;
+                buffer[i_write * 4 + 1] = values.y;
+                buffer[i_write * 4 + 2] = values.z;
+                buffer[i_write * 4 + 3] = values.w;
+            }
+            else if constexpr(elem_size == 2)
+            {
+                // # elements per write = 2
+                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
+
+                buffer[i_write * 2 + 0] = values.x;
+                buffer[i_write * 2 + 1] = values.y;
+            }
+            else if constexpr(elem_size == 4)
+            {
+                // # elements per write = 1
+                constexpr elem_type value = 0;
+
+                buffer[i_write] = value;
+            }
+            else
+            {
+                static_assert(false, "type not supported");
+            }
+        });
+#else
         using dvec_t = array<index_t, tensor_bytes / 4>;
         auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
         for(auto i = 0; i < tensor.size(); i++)
             tensor.get(i) = v;
+#endif
     }
     else
     {
-        tile_elementwise_inout(
-            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
-            dstr_tensor);
+        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
+                               dstr_tensor);
     }
 }
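Usage sketch (acc is a placeholder static distributed tensor): the new third argument lets a caller explicitly opt out of the per-dword path.

    set_tile(acc, number<0>{});                         // per-dword zero fill when tensor_bytes % 4 == 0
    set_tile(acc, number<0>{}, bool_constant<true>{});  // skip_subdword_opt: plain element-wise fill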
@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ serious of function

@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 #endif
 }

+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ serious of function
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+
+    // TODO: this is rtz cvt, need be very careful
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
+
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
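A standalone sketch (not from the commit) of what __builtin_amdgcn_cvt_pkrtz does in the new fp16 cast above: it converts two fp32 values to a packed pair of fp16 using round-toward-zero, which is why the code carries the "rtz cvt, need be very careful" warning. The assumption here is that ck_tile's fp16x2_t is convertible from the builtin's two-lane half result.

    // Device-side sketch: pack two floats into fp16x2 with round-toward-zero
    // (maps to the v_cvt_pkrtz_f16_f32 instruction); o.x / o.y hold the halves.
    CK_TILE_DEVICE fp16x2_t pack_two_floats_rtz(float a, float b)
    {
        fp16x2_t o = __builtin_amdgcn_cvt_pkrtz(a, b);
        return o;
    }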
@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                       float> &&
                       (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
include/ck_tile/core/tensor/tile_window.hpp  (view file @ 66593407)

@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
         return dst_tensor;
     }

-    template <typename DstTile, bool oob_conditional_check = true>
-    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
-                                 bool_constant<oob_conditional_check> = {}) const
+    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop> = {}) const
     {
         using Traits = load_store_traits;

@@ -374,6 +375,12 @@ struct tile_window_with_static_distribution
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();

                 // data index [y0, y1, ...]
                 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

@@ -384,8 +391,12 @@ struct tile_window_with_static_distribution
                 get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                     dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                     bottom_tensor_thread_coord,
-                    bool_constant<oob_conditional_check>{});
+                    bool_constant<oob_conditional_check>{},
+                    pre_nop_);
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+                asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
+#endif

                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                 {
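For readers unfamiliar with the pre_nop_ idiom above, a compact sketch (illustrative only, independent of ck_tile): an immediately-invoked constexpr lambda is the usual way to pick between two different types (bool_constant<true> vs bool_constant<false>) at compile time, which a plain ternary cannot do; here it restricts the extra s_nop to the very first access of the window.

    // Illustrative pattern: select a distinct *type* per branch at compile time.
    template <bool pre_nop, int iCoord, int iCoordAccess>
    constexpr auto make_pre_nop_flag()
    {
        constexpr auto flag = []() {
            if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
                return bool_constant<true>{};   // only the very first access pays the s_nop
            else
                return bool_constant<false>{};
        }();
        return flag;
    }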
@@ -399,12 +410,17 @@ struct tile_window_with_static_distribution
             }
         });
     });
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
+                     "scratch memory" ::);
+#endif
     }

     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
-                                   bool_constant<oob_conditional_check> = {}) const
+    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       bool_constant<oob_conditional_check> = {},
+                                       bool_constant<pre_nop> = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         // using LdsTensorView = typename LdsTileWindow::BottomTensorView;

@@ -450,10 +466,16 @@ struct tile_window_with_static_distribution
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();

                 // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord);
+                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                    smem, bottom_tensor_thread_coord, pre_nop_);

                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
@@ -608,6 +630,67 @@ struct tile_window_with_static_distribution
         });
     }

+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+
+#if 0 // debug
+        // TODO: this use more register for FA, but less register for GEMM
+        // need investigation
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
+
+        if constexpr(NDimP == 1)
+        {
+            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+                tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
+        }
+        else if constexpr(NDimP == 2)
+        {
+            window_adaptor_thread_coord_tmp =
+                make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
+                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
+        }
+#else
+        // TODO: this use less register for FA, but more register for GEMM
+        // need investigation
+        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+            tile_dstr_.get_ps_ys_to_xs_adaptor(),
+            container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
+#endif
+
+        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
+            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
+
+        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
+            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
+
+        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
+        // future load/store() calls (might allocate more registers)
+        using Traits = load_store_traits;
+        using SFC_Ys = typename Traits::SFC_Ys;
+
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
+            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
+
+            constexpr auto idx_diff_ys =
+                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
+
+            constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+            move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+
+            pre_computed_coords_(iCoord) = make_tuple(window_adaptor_thread_coord,
+                                                      bottom_tensor_thread_coord);
+        });
+    }
+
+    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
+
     // this is the bottom tensor view
     // [x0', x1', ...] ==> [offset]
     BottomTensorView bottom_tensor_view_;
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  (view file @ 66593407)

@@ -78,6 +78,12 @@ struct BlockFmhaPipelineQRKSVSAsync
             return Problem::kBlockPerCu;
         else
         {
+            // minimize occupancy
+            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS)
+            {
+                return 1;
+            }
+
             if constexpr(kK0BlockLength <= 32)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&

@@ -212,11 +218,14 @@ struct BlockFmhaPipelineQRKSVSAsync
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
             Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+        q_dram_window.init_raw();

         // TODO: we use async Copy for K, which is inline asm
         // a side effect is we have to use inline asm for q as well
         auto q = decltype(load_tile(q_dram_window)){};
-        set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
+        // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
+        //       however, q would be cleared in the constructor of static distributed tensor
+        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
         load_tile_raw(q, q_dram_window);
         __builtin_amdgcn_sched_barrier(0);

@@ -285,6 +294,16 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_dram_block_window.get_window_origin(),
             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load
+        k_dram_window.init_raw();
+
+        constexpr auto k_oob_ck = bool_constant<true>{};
+        constexpr auto k_pre_np = [&]() {
+            if constexpr(kPadSeqLenK &&
+                         (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                          (BiasEnum != BlockAttentionBiasEnum::NO_BIAS)))
+                return bool_constant<true>{};
+            else
+                return bool_constant<false>{};
+        }();

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window  = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -299,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             Policy::template MakeVDramTileDistribution<Problem>());

         // prefetch K tile
-        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
         move_tile_window(k_dram_window, {0, kK0});
         __builtin_amdgcn_sched_barrier(0);

@@ -322,7 +341,9 @@ struct BlockFmhaPipelineQRKSVSAsync
         {
             static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
-                async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
-                                    k_dram_window);
+                async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
+                                    k_dram_window,
+                                    k_oob_ck,
+                                    k_pre_np);
                 if constexpr(i_k0 < k0_loops - 1)
                     move_tile_window(k_dram_window, {0, kK0});

@@ -609,16 +630,13 @@ struct BlockFmhaPipelineQRKSVSAsync
             {
                 // move K tile windows
                 move_tile_window(k_dram_block_window, {kN0, 0});
-                k_dram_window =
-                    make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
-                                     k_dram_block_window.get_window_lengths(),
-                                     k_dram_block_window.get_window_origin(),
-                                     Policy::template MakeKDramTileDistribution<Problem>());
+                k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());

                 if constexpr(k1_loops >= 2 &&
                              LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                     __builtin_amdgcn_s_barrier();
-                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window,
+                                    k_oob_ck, k_pre_np);
                 move_tile_window(k_dram_window, {0, kK0});
             }

             // tail
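To summarize the call-order contract this commit introduces across the files above, here is a condensed, illustrative sketch (window names and factory arguments are placeholders, not the pipeline's real ones): any tile window used through the *_raw / async_*_raw entry points must be primed with init_raw() before its first load so the cached wave buffer resource is valid.

    // Condensed sketch of the new usage pattern (illustrative names):
    auto q_dram_window = make_tile_window(q_view, q_lengths, q_origin, q_distribution);
    q_dram_window.init_raw();                       // cache the wave buffer resource
    load_tile_raw(q, q_dram_window);                // raw inline-asm load

    auto k_dram_window = make_tile_window(k_view, k_lengths, k_origin, k_distribution);
    k_dram_window.init_raw();
    async_load_tile_raw(k_lds_store, k_dram_window, // async copy DRAM -> LDS
                        bool_constant<true>{},      // oob check
                        bool_constant<false>{});    // pre_nop
    async_load_fence();                             // wait for outstanding async loads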