Commit eed60199 ("more robust api")
Project: gaoqiong/composable_kernel_ROCM
Authored Sep 13, 2024 by carlushuang
Parent: cae751d1

Showing 20 changed files with 1727 additions and 210 deletions (+1727 −210)
Changed files:

example/ck_tile/05_moe/topk_softmax_api.cpp (+25 −1)
include/ck_tile/core.hpp (+1 −0)
include/ck_tile/core/algorithm/space_filling_curve.hpp (+23 −2)
include/ck_tile/core/arch/amd_buffer_addressing.hpp (+54 −45)
include/ck_tile/core/container/tuple.hpp (+24 −0)
include/ck_tile/core/tensor/buffer_view.hpp (+74 −53)
include/ck_tile/core/tensor/load_tile.hpp (+84 −1)
include/ck_tile/core/tensor/store_tile.hpp (+27 −0)
include/ck_tile/core/tensor/tensor_view.hpp (+156 −10)
include/ck_tile/core/tensor/tile_distribution.hpp (+2 −0)
include/ck_tile/core/tensor/tile_window.hpp (+53 −11)
include/ck_tile/core/tensor/tile_window_linear.hpp (+1055 −0)
include/ck_tile/core/utility/magic_div.hpp (+22 −5)
include/ck_tile/host/host_tensor.hpp (+23 −0)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp (+16 −16)
include/ck_tile/ops/gemm/warp/warp_gemm.hpp (+75 −55)
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp (+2 −2)
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp (+1 −1)
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp (+8 −5)
include/ck_tile/ops/reduce/block/block_reduce.hpp (+2 −3)
example/ck_tile/05_moe/topk_softmax_api.cpp

@@ -25,7 +25,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        using ts_input_type  = ck_tile::fp16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;
#if 1
        if(t.experts <= 8) { TOPK_SOFTMAX_DISPATCH(8)

@@ -42,9 +42,24 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        { TOPK_SOFTMAX_DISPATCH(64) }
        else if(t.experts <= 128) { TOPK_SOFTMAX_DISPATCH(128) }
        else if(t.experts <= 192) { TOPK_SOFTMAX_DISPATCH(192) }
#else
        if(t.experts <= 16) { TOPK_SOFTMAX_DISPATCH(16) }
#endif
    }
    else if(t.input_type == "bf16" && t.weight_type == "fp32")
    {
#if 1
        using ts_input_type  = ck_tile::bf16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;

@@ -64,6 +79,15 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
        { TOPK_SOFTMAX_DISPATCH(64) }
        else if(t.experts <= 128) { TOPK_SOFTMAX_DISPATCH(128) }
        else if(t.experts <= 192) { TOPK_SOFTMAX_DISPATCH(192) }
#endif
    }
    return -1;
}
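The added branches extend the runtime dispatch for both the fp16 and bf16 paths up to 192 experts, instead of falling through to the `return -1` failure path. A minimal sketch of the selection rule these chains implement (illustrative only: the capacities between 8 and 64 are elided in this view and assumed here, and the real work is done by the `TOPK_SOFTMAX_DISPATCH` macro defined elsewhere in the example):

    // Hypothetical helper mirroring the if/else-if chains above: pick the smallest
    // supported expert capacity that covers the runtime expert count, or -1 if none does.
    inline int pick_expert_capacity(int experts)
    {
        const int caps[] = {8, 16, 32, 64, 128, 192}; // assumed capacity list, not taken from this commit
        for(int cap : caps)
            if(experts <= cap)
                return cap;
        return -1; // mirrors the "return -1" fall-through in topk_softmax()
    }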
include/ck_tile/core.hpp

@@ -50,6 +50,7 @@
 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
include/ck_tile/core/algorithm/space_filling_curve.hpp

@@ -66,6 +66,20 @@ struct space_filling_curve
        return idx_tail - idx_head;
    }

    template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
    static CK_TILE_HOST_DEVICE constexpr auto get_step_between_static(number<AccessIdx1dHead>,
                                                                      number<AccessIdx1dTail>)
    {
        static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
                      "1D index out of range");
        static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
                      "1D index out of range");

        constexpr auto idx_head = get_index_static(number<AccessIdx1dHead>{});
        constexpr auto idx_tail = get_index_static(number<AccessIdx1dTail>{});

        return idx_tail - idx_head;
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
    {

@@ -73,6 +87,13 @@ struct space_filling_curve
        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step_static(number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
        return get_step_between_static(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
    {

@@ -153,9 +174,9 @@ struct space_filling_curve
        return idx_md;
    }

-   // FIXME: rename this function
+   // FIXME: return tuple of number<>, which is compile time only variable
    template <index_t AccessIdx1d>
-   static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+   static CK_TILE_HOST_DEVICE constexpr auto get_index_static(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});
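The new `_static` helpers mirror `get_step_between`/`get_forward_step` but are built on `get_index_static`, so the step comes back as a tuple of `number<>` whose values stay visible to the type system rather than as runtime `index_t` values; this is what lets callers fold the step into compile-time immediate offsets. A hedged sketch of the difference, assuming some concrete instantiation `sfc_t` of `space_filling_curve<...>` (a placeholder name, not taken from this commit):

    // Both calls are constexpr, but only the _static variant yields number<>-typed components.
    constexpr auto step_runtime_valued = sfc_t::get_forward_step(ck_tile::number<3>{});
    constexpr auto step_compile_time   = sfc_t::get_forward_step_static(ck_tile::number<3>{});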
include/ck_tile/core/arch/amd_buffer_addressing.hpp

@@ -156,8 +156,8 @@ struct buffer_load<2, pre_nop>
                               index_t /*flag*/       = 0,
                               bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       // static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"

@@ -315,9 +315,9 @@ struct buffer_load_if<2, pre_nop>
                               index_t flag           = 0,
                               bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4);
+       // static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"

@@ -676,19 +676,17 @@ template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
} // namespace impl

// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
template <index_t>
struct smem_load;

template <>
struct smem_load<16>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)

@@ -700,9 +698,7 @@ template <>
struct smem_load<8>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;

@@ -717,9 +713,7 @@ template <>
struct smem_load<4>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;

@@ -734,11 +728,10 @@ template <>
struct smem_load<2>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u16 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)

@@ -750,9 +743,7 @@ template <>
struct smem_load<1>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;

@@ -1879,6 +1870,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                              int32x4_t dst_wave_buffer_resource,
                                              index_t dst_thread_addr_offset,
                                              index_t dst_wave_addr_offset,
                                              index_t dst_linear_addr_offset,
                                              index_t is_valid_element = 1)
{
    constexpr index_t bytes = sizeof(T) * N;

@@ -1892,7 +1884,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0,
+                                             dst_linear_addr_offset,
                                              is_valid_element);
    }
    else

@@ -1901,7 +1893,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                              dst_wave_buffer_resource,
                                              dst_thread_addr_offset,
                                              dst_wave_addr_offset,
-                                             0);
+                                             dst_linear_addr_offset);
    }
}

@@ -2266,6 +2258,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const T* p_src_wave,
                                                       index_t src_thread_element_offset,
                                                       index_t src_linear_element_offset,
                                                       index_t src_element_space_size,
                                                       bool_constant<pre_nop> = {})
{

@@ -2273,9 +2266,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(
+       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, src_linear_addr_offset,
+       bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg

@@ -2286,12 +2284,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const int32x4_t src_wave_buffer_resource,
                                                       index_t src_thread_element_offset,
                                                       index_t src_linear_element_offset,
                                                       bool_constant<pre_nop> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(
+       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, src_linear_addr_offset,
+       bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg

@@ -2302,16 +2306,18 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
                                                   index_t src_linear_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           src_thread_addr_offset,
                                           0,
                                           0,
                                           src_linear_addr_offset,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}

@@ -2368,6 +2374,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
                                         const index_t dst_linear_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
{

@@ -2375,11 +2382,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);

    amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
                                                                      dst_wave_buffer_resource,
                                                                      dst_thread_addr_offset,
                                                                      0,
                                                                      dst_linear_addr_offset,
                                                                      dst_thread_element_valid);
}
include/ck_tile/core/container/tuple.hpp

@@ -635,6 +635,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
    return r;
}

template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);
    return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =

@@ -649,6 +657,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
    return r;
}

template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);
    return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =

@@ -686,6 +702,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
    return a * x;
}

template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    constexpr index_t NSize = sizeof...(Xs);
    return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}

template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
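A brief usage sketch of the new element-wise tuple operators (illustrative values only). Both operands must have the same number of elements, which the `static_assert` enforces, and the result is built with `generate_tuple`:

    // Element-wise arithmetic on ck_tile tuples of equal size.
    constexpr auto x = ck_tile::make_tuple(1, 2, 3);
    constexpr auto y = ck_tile::make_tuple(4, 5, 6);

    constexpr auto sum  = x + y; // tuple{5, 7, 9}
    constexpr auto diff = x - y; // tuple{-3, -3, -3}
    constexpr auto prod = x * y; // tuple{4, 10, 18}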
include/ck_tile/core/tensor/buffer_view.hpp

@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto
+   get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
-           return *c_style_pointer_cast<const X*>(&p_data_[i]);
+           return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
        }
        else

@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }

@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;
-           __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+           __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-           *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+           *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }

@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto
+   get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
                    t_per_x,
                    Coherence,
                    oob_conditional_check>(
-                   p_data_, i, is_valid_element, buffer_size_);
+                   p_data_, i + linear_offset, is_valid_element, buffer_size_);
            }
            else
            {

@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
                    remove_cvref_t<T>,
                    t_per_x,
                    Coherence,
-                   oob_conditional_check>(p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
+                   oob_conditional_check>(
+                   p_data_, i + linear_offset, is_valid_element, buffer_size_, invalid_element_value_);
            }
        }
        else

@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
-           return *c_style_pointer_cast<const X*>(&p_data_[i]);
+           return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
        }
        else

@@ -379,6 +386,7 @@ struct buffer_view<address_space_enum::global,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
                                            index_t i,
                                            index_t linear_offset,
                                            bool is_valid_element,
                                            bool_constant<oob_conditional_check> = {}) const
    {

@@ -392,7 +400,12 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-           smem, cached_buf_res_, i, is_valid_element, bool_constant<oob_conditional_check>{});
+           smem,
+           cached_buf_res_,
+           i,
+           linear_offset,
+           is_valid_element,
+           bool_constant<oob_conditional_check>{});
    }

    // i is offset of T, not X. i should be aligned to X

@@ -404,6 +417,7 @@ struct buffer_view<address_space_enum::global,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
                                                index_t i,
                                                index_t linear_offset,
                                                bool /*is_valid_element*/,
                                                bool_constant<pre_nop> = {}) const
    {

@@ -417,7 +431,7 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
-           smem, cached_buf_res_, i, bool_constant<pre_nop>{});
+           smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
    }

    // i is offset of T, not X. i should be aligned to X

@@ -427,11 +441,11 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {

@@ -458,7 +472,7 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -479,7 +493,7 @@ struct buffer_view<address_space_enum::global,
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
            amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else
        {

@@ -488,9 +502,9 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;
-           __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+           __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-           *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+           *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }

@@ -503,7 +517,7 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -515,7 +529,7 @@ struct buffer_view<address_space_enum::global,
        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
        amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-           x, p_data_, i, is_valid_element, buffer_size_);
+           x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
    }

    template <typename X,

@@ -523,7 +537,8 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

@@ -558,13 +573,13 @@ struct buffer_view<address_space_enum::global,
        if constexpr(use_amd_buffer_addressing)
        {
            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
-               atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
+               atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
            }
        }
    }

@@ -574,7 +589,8 @@ struct buffer_view<address_space_enum::global,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -598,11 +614,11 @@ struct buffer_view<address_space_enum::global,
        if constexpr(use_amd_buffer_addressing)
        {
            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
-               x, p_data_, i, is_valid_element, buffer_size_);
+               x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
        }
        else if(is_valid_element)
        {
-           atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
+           atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
        }
    }

@@ -694,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto
+   get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -710,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;
-           __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
+           __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
            return tmp;
#else
            using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                       scalar_per_t_vector * scalar_per_x_vector>;
            // using buf_t = ushort __attribute__((ext_vector_type(8)));
-           auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
+           auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
            return bit_cast<X>(rtn);
#endif
        }

@@ -745,7 +763,7 @@ struct buffer_view<address_space_enum::lds,
    CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
                                          index_t v_offset,
                                          index_t i_offset,
-                                         bool is_valid_element,
+                                         bool /*is_valid_element*/,
                                          bool_constant<pre_nop> = {}) const
    {
#if 0

@@ -768,17 +786,17 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }

@@ -788,7 +806,7 @@ struct buffer_view<address_space_enum::lds,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -804,6 +822,7 @@ struct buffer_view<address_space_enum::lds,
        bool constexpr workaround_int8_ds_write_issue = false;
#endif
+       i += linear_offset; // simplicity
        if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                  int8_t>::value &&
                     workaround_int8_ds_write_issue)

@@ -1005,8 +1024,10 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE constexpr auto
-   get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
+   CK_TILE_DEVICE constexpr auto
+   get(index_t i, index_t /*linear_offset*/, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -1048,17 +1069,17 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
-           this->template set<X>(i, is_valid_element, x);
+           this->template set<X>(i, linear_offset, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
-           auto tmp = this->template get<X>(i, is_valid_element);
-           this->template set<X>(i, is_valid_element, x + tmp);
+           auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
+           this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
        }
    }

@@ -1068,7 +1089,7 @@ struct buffer_view<address_space_enum::vgpr,
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
-   CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
+   CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -1083,9 +1104,9 @@ struct buffer_view<address_space_enum::vgpr,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;
-           __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
+           __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
-           *c_style_pointer_cast<X*>(&p_data_[i]) = x;
+           *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
        }
    }
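Across all address spaces the buffer_view accessors now take an extra `linear_offset` argument; for the plain addressing paths shown above the effective element index is simply `i + linear_offset`, while the buffer-resource raw paths forward it as a separate operand so it can end up in the instruction's immediate offset field. A hedged sketch of the calling-convention change from a caller's perspective (`view`, `vec_t`, and `immediate_elems` are placeholder names, not taken from this commit):

    // Before this commit the accessors took (i, is_valid, x); they now take
    // (i, linear_offset, is_valid, x). Existing call sites pass 0 for the new argument.
    view.template set<vec_t>(i, 0, is_valid, value);               // behaves like the old set(i, is_valid, value)
    view.template set<vec_t>(i, immediate_elems, is_valid, value); // effective element index i + immediate_elems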
include/ck_tile/core/tensor/load_tile.hpp

@@ -11,7 +11,7 @@
 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"
-#include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/null_tile_window.hpp"
 #include "ck_tile/core/tensor/null_tensor.hpp"

@@ -31,6 +31,20 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,

@@ -49,6 +63,24 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_linear<BottomTensorView_,
                                                           WindowLengths_,
                                                           TileDistribution_,
                                                           LinearBottomDims_>& tile_window,
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect the
// dependency if using multiple smem window (multiple buffer)

@@ -69,6 +101,22 @@ async_load_tile(LdsTileWindow_&& lds_tile,
    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
                                    const tile_window_linear<BottomTensorView_,
                                                             WindowLengths_,
                                                             TileDistribution_,
                                                             LinearBottomDims_>& tile_window,
                                    bool_constant<oob_conditional_check> = {})
{
    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,

@@ -89,6 +137,25 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                        const tile_window_linear<BottomTensorView_,
                                                                 WindowLengths_,
                                                                 TileDistribution_,
                                                                 LinearBottomDims_>& tile_window,
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop>               = {})
{
    return tile_window.async_load_raw(
        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{

@@ -100,4 +167,20 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
{
}

// TODO: this function requires some sub-fileds exist for the target tile window
template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {})
{
    using TileDstr = typename TileWindow::TileDstr;
    using DataType = typename TileWindow::DataType;

    auto t = make_static_distributed_tensor<DataType>(TileDstr{});
    load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
    return t;
}

} // namespace ck_tile
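The new overloads give `tile_window_linear` the same call surface as `tile_window_with_static_distribution`, so kernel code can swap the window type without touching its load calls. A hedged usage sketch (the window object `win` is assumed to be a `tile_window_linear` created elsewhere, and the trailing `bool_constant` arguments are optional):

    // Same calls work for both window types.
    auto tile = load_tile(win);                               // OOB-checked load
    load_tile_raw(tile, win, ck_tile::bool_constant<true>{}); // raw path, explicit OOB check
    auto tile2 = load_tile_raw(win);                          // new generic helper: builds the distributed tensor and loads into it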
include/ck_tile/core/tensor/store_tile.hpp

@@ -10,6 +10,7 @@
 #include "ck_tile/core/container/container_helper.hpp"
 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

@@ -90,4 +91,30 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
    tile_window.store_raw(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_>
CK_TILE_DEVICE void store_tile(tile_window_linear<BottomTensorView_,
                                                  WindowLengths_,
                                                  TileDistribution_,
                                                  LinearBottomDims_>& tile_window,
                               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_>
CK_TILE_DEVICE void store_tile_raw(tile_window_linear<BottomTensorView_,
                                                      WindowLengths_,
                                                      TileDistribution_,
                                                      LinearBottomDims_>& tile_window,
                                   const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}

} // namespace ck_tile
include/ck_tile/core/tensor/tensor_view.hpp

@@ -75,14 +75,34 @@ struct tensor_view
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            index_t linear_offset,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            index_t linear_offset,
                            bool is_valid_element, // flag
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(), linear_offset, is_valid_element, bool_constant<oob_conditional_check>{});
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,

@@ -106,6 +126,24 @@ struct tensor_view
            bool_constant<pre_nop>{});
    }

    template <typename X,
              bool oob_conditional_check = true,
              bool pre_nop               = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                                         const TensorCoord& coord,
                                                         index_t linear_offset,
                                                         bool is_valid_element,
                                                         bool_constant<oob_conditional_check> = {},
                                                         bool_constant<pre_nop>               = {}) const
    {
        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
            dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<

@@ -114,26 +152,71 @@ struct tensor_view
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
-                                 const TensorCoord& coord) const
+                                 const TensorCoord& coord,
+                                 index_t linear_offset) const
    {
        return buf_.template async_get<X>(
            smem,
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                                  const TensorCoord& coord,
                                  index_t linear_offset,
                                  bool is_valid_element) const
    {
        return buf_.template async_get<X>(
            smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<oob_conditional_check>{});
    }

    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
                                      const TensorCoord& coord,
+                                     index_t linear_offset,
                                      bool_constant<pre_nop> = {}) const
    {
-       return buf_.template async_get_raw<X>(
-           smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
+       return buf_.template async_get_raw<X>(
+           smem,
+           coord.get_offset(),
+           linear_offset,
+           coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+           bool_constant<pre_nop>{});
    }

    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
                                      const TensorCoord& coord,
                                      index_t linear_offset,
                                      bool is_valid_element,
                                      bool_constant<pre_nop> = {}) const
    {
        return buf_.template async_get_raw<X>(
            smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
    }

    // X is vector of DataType.

@@ -144,11 +227,15 @@ struct tensor_view
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(const TensorCoord& coord,
-                                                              const X& x,
-                                                              bool_constant<oob_conditional_check> = {})
+   CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(const TensorCoord& coord,
+                                                              index_t linear_offset,
+                                                              const X& x,
+                                                              bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

@@ -159,15 +246,53 @@ struct tensor_view
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(const TensorCoord& coord,
-                                                                  const X& x,
-                                                                  bool_constant<oob_conditional_check> = {})
+   CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(const TensorCoord& coord,
+                                                              index_t linear_offset,
+                                                              bool is_valid_element,
+                                                              const X& x,
+                                                              bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(coord.get_offset(), linear_offset, is_valid_element, x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(const TensorCoord& coord,
                                                                   index_t linear_offset,
                                                                   const X& x,
                                                                   bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(const TensorCoord& coord,
                                                                   index_t linear_offset,
                                                                   bool is_valid_element,
                                                                   const X& x,
                                                                   bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(coord.get_offset(), linear_offset, is_valid_element, x);
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,

@@ -176,15 +301,36 @@ struct tensor_view
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(const TensorCoord& coord,
-                                                                 const X& x,
-                                                                 bool_constant<oob_conditional_check> = {})
+   CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(const TensorCoord& coord,
+                                                                 index_t linear_offset,
+                                                                 const X& x,
+                                                                 bool_constant<oob_conditional_check> = {})
    {
        buf_.template update<DstInMemOp, X, oob_conditional_check>(
            coord.get_offset(),
            linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(const TensorCoord& coord,
                                                                  index_t linear_offset,
                                                                  bool is_valid_element,
                                                                  const X& x,
                                                                  bool_constant<oob_conditional_check> = {})
    {
        buf_.template update<DstInMemOp, X, oob_conditional_check>(
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");
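In short, every vectorized accessor of tensor_view gains a `linear_offset` parameter, and each also gets a sibling overload that accepts an explicit `is_valid_element` flag instead of recomputing validity from the coordinate via `coordinate_has_valid_offset_assuming_top_index_is_valid`. A hedged sketch of the two call shapes (local variable names are placeholders):

    // Validity derived from the coordinate, as before, plus the new linear offset:
    auto v0 = view.template get_vectorized_elements<vec_t>(coord, linear_off);
    // Validity supplied by the caller, e.g. when it was pre-computed and cached:
    auto v1 = view.template get_vectorized_elements<vec_t>(coord, linear_off, is_valid);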
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -454,6 +454,7 @@ struct tile_distribution_detail
} // namespace detail

+#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)

@@ -490,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
+#endif

// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
include/ck_tile/core/tensor/tile_window.hpp

@@ -223,10 +223,11 @@ struct tile_window_with_static_distribution
    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
+   template <typename ATopIndex>
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
-       const AdaptorTopIndex& idx_diff_adaptor_top) const
+       const ATopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

@@ -309,7 +310,7 @@ struct tile_window_with_static_distribution
                // read from bottom tensor
                const vector_t vec_value =
                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
-                       bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
+                       bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
                // write into distributed tensor
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {

@@ -337,10 +338,11 @@ struct tile_window_with_static_distribution
                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
-                   constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+                   constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);

-                   constexpr auto idx_diff_ps_ys =
-                       container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                   constexpr auto idx_diff_ps_ys = container_concat(
+                       generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

@@ -398,7 +400,7 @@ struct tile_window_with_static_distribution
                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                    bottom_tensor_thread_coord,
                    0 /**/,
                    bool_constant<oob_conditional_check>{},
                    pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \

@@ -484,7 +486,7 @@ struct tile_window_with_static_distribution
                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
-                   smem, bottom_tensor_thread_coord, pre_nop_);
+                   smem, bottom_tensor_thread_coord, 0, pre_nop_);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -552,7 +554,7 @@ struct tile_window_with_static_distribution
                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                   smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
+                   smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -618,7 +620,10 @@ struct tile_window_with_static_distribution
                // write into bottom tensor
                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
-                   bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+                   bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -676,7 +681,7 @@ struct tile_window_with_static_distribution
                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
-                       bottom_tensor_thread_coord, vec_value);
+                       bottom_tensor_thread_coord, 0, vec_value);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -736,7 +741,10 @@ struct tile_window_with_static_distribution
                // write into bottom tensor
                get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
-                   bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+                   bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -868,6 +876,27 @@ make_tile_window(const TensorView_& tensor_view,
        tensor_view, window_lengths, origin, tile_distribution};
}

// this version must not be called under a constexpr context
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE auto
make_tile_window_raw(const TensorView_& tensor_view,
                     const WindowLengths_& window_lengths,
                     const multi_index<TensorView_::get_num_of_dimension()>& origin,
                     const StaticTileDistribution_& tile_distribution,
                     number<NumCoord> = {})
{
    auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                  remove_cvref_t<WindowLengths_>,
                                                  remove_cvref_t<StaticTileDistribution_>,
                                                  NumCoord>{
        tensor_view, window_lengths, origin, tile_distribution};
    w.init_raw();
    return w;
}

template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,

@@ -992,6 +1021,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
        tile_distribution);
}

template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                     const StaticTileDistribution& tile_distribution)
{
    auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
                              tile_window.get_window_lengths(),
                              tile_window.get_window_origin(),
                              tile_distribution);
    w.init_raw();
    return w;
}

template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
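The two `make_tile_window_raw` helpers are thin wrappers: they construct the window exactly like `make_tile_window` and then call `init_raw()` on it, so the returned window is ready for the `*_raw` load/store paths; as the comment notes, the first form must not be called in a constexpr context. A hedged usage sketch (argument names are placeholders for objects created elsewhere):

    // Equivalent to make_tile_window(...) followed by w.init_raw().
    auto win  = make_tile_window_raw(tensor_view, window_lengths, origin, dstr);
    // Overload that re-wraps an existing static-length window with a distribution:
    auto win2 = make_tile_window_raw(static_len_window, dstr);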
include/ck_tile/core/tensor/tile_window_linear.hpp
0 → 100644
View file @
eed60199
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// This version of tile window pre-caches offsets/flags based on need.
//
// LinearBottomDims_, e.g. seq<0, 1> for a 2d tensor: the last dim is the linear dim,
// so the last dim can use an immediate offset for indexing, which saves registers.
// TODO: if using this struct, prefer load_raw()/store_raw(), which can control
//       the immediate offset on the fly.
// Note: the space-filling curve is non-snaked here!
//
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
struct tile_window_linear
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using TileDstr         = remove_cvref_t<StaticTileDistribution_>;
    using WindowAdaptor    = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType         = remove_cvref_t<typename BottomTensorView::DataType>;
    using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;

    static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());

    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor     = BottomTensorDesc::get_num_of_dimension();
    static constexpr index_t NDimP                = TileDstr::get_num_of_dimension_p();
    static constexpr index_t NDimY                = TileDstr::get_num_of_dimension_y();

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // TODO: check WindowLengths and StaticTileDistribution are consistent
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");
    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of dimensions");

    using AdaptorTopIndex   = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));

    struct traits
    {
        private:
        // return vector dimension among [y0, y1, ...]
        CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
        {
            // bottom tensor top dimension vector lengths and strides
            const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
                BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

            // window vector lengths/strides
            const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
            const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;

            // window adaptor [p0, p1, ..., y0, y1, ...]
            array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{-1};
            array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{-1};

            constexpr auto window_adaptor_bottom_dims = WindowAdaptor::get_bottom_dimension_hidden_ids();

            set_container_subset(window_adaptor_vector_lengths,
                                 window_adaptor_bottom_dims,
                                 window_adaptor_bottom_dim_vector_lengths);
            set_container_subset(window_adaptor_vector_strides,
                                 window_adaptor_bottom_dims,
                                 window_adaptor_bottom_dim_vector_strides);

            const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
                WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
                    window_adaptor_vector_lengths, window_adaptor_vector_strides);

            // [y0, y1, ...]
            constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
                                                                     NDimWindowAdaptorTop,
                                                                     1>::type{};

            return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
                              get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
        }

        static constexpr auto get_vector_dim_y_scalar_per_vector()
        {
            const auto [ys_vector_lengths, ys_vector_strides] =
                get_window_adaptor_ys_safe_vector_length_strides();

            index_t VectorDimY_      = 0;
            index_t ScalarPerVector_ = 1;

            for(index_t i = 0; i < NDimY; ++i)
            {
                if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
                {
                    ScalarPerVector_ = ys_vector_lengths[i];
                    VectorDimY_      = i;
                }
            }

            return make_tuple(VectorDimY_, ScalarPerVector_);
        }

        public:
        static constexpr index_t VectorDimY      = get_vector_dim_y_scalar_per_vector().template at<0>();
        static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>();

        using vector_t = thread_buffer<DataType, ScalarPerVector>;

        private:
        static constexpr auto scalars_per_access_ = [] {
            constexpr auto scalars_per_access_arr = generate_array(
                [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});

            /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
            constexpr auto NDimY_ = NDimY;
            return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
        }();

        static constexpr auto get_space_filling_curve()
        {
            constexpr auto thread_tensor_lengths_ys =
                to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());

            // FIXME: need logic to judge dim access order
            using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;

            return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                       DimAccessOrder,
                                       decltype(scalars_per_access_),
                                       false /*!!! no snaked curve! */>{};
        }

        public:
        using SFC_Ys = decltype(get_space_filling_curve());

        static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();

        static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");

        private:
        static constexpr auto get_num_non_linear_access()
        {
            constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;

            constexpr auto non_linear = [&]() {
                index_t cnt = 1;
                static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
                    constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                    if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                    {
                        cnt *= sfc_access_lens[i_dim_y];
                    }
                });
                return cnt;
            }();

            return non_linear;
        }

        // example:
        // non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses, totally 2 registers used
        //      -> histogram : sequence<4, 4>
        //      -> prefixsum : sequence<0, 4, 8>
        // non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses, totally 8 registers used, will pre-cache 8
        //      -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
        //      -> prefixsum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
        // non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses, totally 4 registers used, will pre-cache 4
        //      -> histogram : sequence<2, 2, 2, 2>
        //      -> prefixsum : sequence<0, 2, 4, 6, 8>
        static constexpr auto get_non_linear_access_map()
        {
            constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;

            constexpr auto non_linear_map = [&]() {
                array<index_t, NumAccess> m_{0};
                index_t cumulative_len_            = 1;
                index_t cumulative_non_linear_len_ = 1;
                static_for<0, NDimY, 1>{}([&](auto i_y) {
                    constexpr auto i_dim_y       = number<NDimY - i_y - 1>{}; // from right to left
                    constexpr auto rhs_major     = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim  = number<rhs_major - 1>{}; // no r dim here!
                    constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];

                    array<index_t, NumAccess> current_m_{0};
                    constexpr auto current_len_ = sfc_access_lens[i_dim_y];

                    // copy cumulative length as current pattern
                    for(auto i_ = 0; i_ < cumulative_len_; i_++)
                    {
                        current_m_(i_) = m_[i_];
                    }

                    for(auto j_ = 0; j_ < current_len_; j_++)
                    {
                        auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
                        for(auto i_ = 0; i_ < cumulative_len_; i_++)
                        {
                            m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
                        }
                    }
                    cumulative_len_ *= current_len_;
                    if(!is_linear_dim)
                        cumulative_non_linear_len_ *= current_len_;
                });
                return m_;
            }();

            return TO_SEQUENCE(non_linear_map, NumAccess);
        }

        static constexpr auto get_non_linear_access_histogram()
        {
            constexpr auto m_ = get_non_linear_access_map();
            // m_.foo();
            constexpr auto r_ =
                typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};

            constexpr auto h_ = histogram_sorted_sequence(m_, r_);
            return h_;
        }

        static constexpr auto get_non_linear_access_histogram_prefix_sum()
        {
            constexpr auto h_            = get_non_linear_access_histogram();
            constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
            return h_prefix_sum_;
        }

        public:
        static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
        using AccessMap_NonLinear       = decltype(get_non_linear_access_map()); // sequence
        using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
        using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
    };
    static constexpr index_t NumAccess           = traits::NumAccess;
    static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
    using AccessMap_NonLinear       = typename traits::AccessMap_NonLinear;
    using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
    using AccessPrefixSum_NonLinear = typename traits::AccessPrefixSum_NonLinear;

    CK_TILE_DEVICE constexpr tile_window_linear() = default;

    CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
                                                const WindowLengths& window_lengths,
                                                const BottomTensorIndex& window_origin,
                                                const TileDstr& tile_distribution)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin},
          tile_dstr_{tile_distribution},
          cached_coords_{},
          cached_flags_{}
    {
        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            tile_distribution.get_ps_ys_to_xs_adaptor(),
            container_concat(make_tuple(get_warp_id(), get_lane_id()),
                             generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));

        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // future load/store() calls (might allocate more registers)
        using SFC_Ys = typename traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
            }

            // TODO: need pad_tensor_view to check which dim need use flag to check
            // cached flag is independent from non-linear-coord
            // but need be updated in move_tile, with proper dims
            cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(i_access); // tuple of number

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys);
            }
        });
    }
    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() { return TileDstr::is_static(); }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    CK_TILE_DEVICE constexpr void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
    {
        bottom_tensor_view_.buf_.p_data_ = data;
    }

    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
    template <typename ATopIndex>
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
        const ATopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

        move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                       window_adaptor_thread_coord,
                                       idx_diff_adaptor_top,
                                       idx_diff_adaptor_bottom);

        move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                               bottom_tensor_thread_coord,
                               idx_diff_adaptor_bottom);
    }

    template <index_t i_access>
    CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
    {
        using SFC_Ys = typename traits::SFC_Ys;

        constexpr auto idx_ys = SFC_Ys::get_index_static(number<i_access>{});

        using ys_to_rhs_major =
            typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;

        constexpr auto modified_idx_ys = generate_tuple(
            [&](auto i_dim_y) {
                constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                {
                    return number<0>{};
                }
                else
                {
                    return number<idx_ys[i_dim_y]>{};
                }
            },
            number<NDimY>{});

        constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
        constexpr auto idx_     = container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);

        return adaptor_.calculate_bottom_index(idx_);
    }

    template <index_t i_access>
    CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
    {
        constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});

        // since this is a linear offset, we assume the bottom X tensor is always linear
        constexpr index_t linear_offset = [&]() {
            constexpr auto x_idx_ = linear_coord;
            constexpr auto x_len_ = TileDstr{}.get_lengths();
            static_assert(x_idx_.size() == x_len_.size());
            constexpr index_t x_dims_ = x_idx_.size();

            index_t cu_stride_ = 1;
            index_t cu_offset_ = 0;
            static_for<0, x_dims_, 1>{}([&](auto i_) {
                auto r_i_ = number<x_dims_ - i_ - 1>{};
                cu_offset_ += x_idx_[r_i_] * cu_stride_;
                cu_stride_ *= x_len_[r_i_];
            });
            return cu_offset_;
        }();

        return linear_offset;
    }

    CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start_static);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        });

        return dst_tensor;
    }

    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 bool_constant<oob_conditional_check> = {},
                                 bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        static constexpr index_t YElementSize =
            TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % traits::ScalarPerVector == 0);
        using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;

        constexpr auto tile_dstr = TileDstr{};

        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess  = number<i_access>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access == 0 &&
                             BottomTensorView::buffer_view::get_address_space() ==
                                 address_space_enum::global)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start_static = SFC_Ys::get_index_static(IAccess);
            constexpr index_t d =
                tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start_static);
            static_assert(d % traits::ScalarPerVector == 0);

            get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
                bottom_tensor_thread_coord,
                linear_offset /**/,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{},
                pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
            asm volatile(""); // this starts from rocm-6.2, but same symptom, reuse this flag
#endif
        });
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
                     "scratch memory" ::);
#endif
    }
    // TODO: currently async load is only implemented in inline asm
    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
                                       bool_constant<oob_conditional_check> = {},
                                       bool_constant<pre_nop>               = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // currently we only support the case where every dim is a non-linear dim;
        // it is actually not performant if we have a linear (e.g. fast-changing) dim here
        static_assert(NumAccess_NonLinear == NumAccess);
        static_assert(BottomTensorView::buffer_view::get_address_space() ==
                      address_space_enum::global);

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);

        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // this should be wave independent

        using vector_t = typename traits::vector_t;

        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess  = number<i_access>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access == 0)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess]; // get this flag anyway

            // read from bottom tensor
            get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
                smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);

            // move thread coordinate
            if constexpr(i_access != (NumAccess - 1))
            {
                m0_inc_with_memory(size_per_issue);
            }
        });
    }

    template <typename LdsTileWindow_, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // currently we only support the case where every dim is a non-linear dim;
        // it is actually not performant if we have a linear (e.g. fast-changing) dim here
        static_assert(NumAccess_NonLinear == NumAccess);
        static_assert(BottomTensorView::buffer_view::get_address_space() ==
                      address_space_enum::global);

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        // TODO: an LDS offset is not good for the intrinsic-based implementation (the compiler
        // can't figure out the dependency), hence avoid an offset-based solution.
        // size_per_buf should be zero (how to check?)
        constexpr index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{}));

        constexpr index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) -
            size_per_buf;

        constexpr index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

        using vector_t = typename traits::vector_t;

        // TODO: we force CK_TILE_LDS_ADDR
        CK_TILE_LDS_ADDR LdsDataType* smem =
            lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // read from bottom tensor
            get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag,
                bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(i_access != (NumAccess - 1))
            {
                smem += size_per_issue; // note we manually increase the per-issue offset
            }
        });
    }
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        });
    }

    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr                    = TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view()
                .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                    bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
        });
    }

    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                               bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_array(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        });
    }
    // move thread's bottom tensor coordinate
    // [x0', x1', ... ] ==> [offset]
    // also move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_update_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_update_non_linear_coord)
            {
                move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                       cached_coords_(non_linear_id),
                                       step);
            }

            // move the current coord with linear_coords
            auto tmp_coords             = cached_coords_[non_linear_id];
            constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
            move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);

            cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
        });
    }

    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
    {
        window_origin_ = new_window_origin;

        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            TileDstr{}.get_ps_ys_to_xs_adaptor(),
            container_concat(make_tuple(get_warp_id(), get_lane_id()),
                             generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));

        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // future load/store() calls (might allocate more registers)
        using SFC_Ys = typename traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
            }

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(i_access); // tuple of number

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys);
            }
        });
    }

    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;

    // Tile tensor distribution, which contains:
    //   1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    //   2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;

    // this contains:
    array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
    array<bool, traits::NumAccess> cached_flags_;
};
namespace impl {

template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
{
    using type = typename uniform_sequence_gen<len_, 0>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::global, len_>
{
    // global defaults to seq<0, 0, ..., 1>
    using type =
        typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type, sequence<1>>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::lds, len_>
{
    // lds defaults to seq<1, 1, ..., 1>
    using type = typename uniform_sequence_gen<len_, 1>::type;
};

} // namespace impl

template <typename TensorView_>
using default_linear_bottom_dims =
    typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
                                                   TensorView_::get_num_of_dimension()>::type;

// If using this API, a tile_window_linear will be created.
// This structure has the chance to use immediate values and save registers.
// LinearBottomDims_ must be passed in properly to control which dim is linear,
// so that a constexpr offset (linear_offset) is generated for that dim
// (and finally passed to the immediate offset of the buffer/lds instruction).
//
// Note: there is no internal check for which dim is OK to use a linear offset;
// the user must make sure of this themselves.
//
// e.g.
//  2d global matrix: set LinearBottomDims_=seq<0, 1>; the last dim will generate an
//     immediate offset if each thread has multiple issues along the last dim.
//
//  2d LDS buffer: set LinearBottomDims_=seq<1, 1>; then only one vgpr is used as offset,
//     everything else just uses immediate offsets.
//
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TensorView_& tensor_view,
                        const WindowLengths_& window_lengths,
                        const multi_index<TensorView_::get_num_of_dimension()>& origin,
                        const StaticTileDistribution_& tile_distribution,
                        LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    return tile_window_linear<remove_cvref_t<TensorView_>,
                              remove_cvref_t<WindowLengths_>,
                              remove_cvref_t<StaticTileDistribution_>,
                              remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
}

template <typename TileWindow_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TileWindow_& tile_window,
                                                      const StaticTileDistribution_& tile_distribution,
                                                      LinearBottomDims_ = {})
{
    return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
                                   tile_window.get_window_lengths(),
                                   tile_window.get_window_origin(),
                                   tile_distribution,
                                   LinearBottomDims_{});
}

// this version must not be called under a constexpr context
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE auto
make_tile_window_linear_raw(const TensorView_& tensor_view,
                            const WindowLengths_& window_lengths,
                            const multi_index<TensorView_::get_num_of_dimension()>& origin,
                            const StaticTileDistribution_& tile_distribution,
                            LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    auto w = tile_window_linear<remove_cvref_t<TensorView_>,
                                remove_cvref_t<WindowLengths_>,
                                remove_cvref_t<StaticTileDistribution_>,
                                remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
    w.init_raw();
    return w;
}

template <typename TileWindow_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto make_tile_window_linear_raw(const TileWindow_& tile_window,
                                                          const StaticTileDistribution_& tile_distribution,
                                                          LinearBottomDims_ = {})
{
    return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
                                       tile_window.get_window_lengths(),
                                       tile_window.get_window_origin(),
                                       tile_distribution,
                                       LinearBottomDims_{});
}

template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>& window,
    const typename tile_window_linear<TensorView_,
                                      WindowLengths_,
                                      StaticTileDistribution_,
                                      LinearBottomDims_>::BottomTensorIndex& step)
{
    window.move(step);
}

} // namespace ck_tile
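A small usage sketch of the linear tile window introduced in this file. It assumes a 2-D global-memory tensor view a_view, compile-time window lengths a_lengths and a static tile distribution a_dstr already exist in the calling kernel; only the calls defined above are taken from this commit.

// Hedged sketch: a_view, a_lengths and a_dstr are assumed to be defined by the caller.
// For a 2-D global tensor the default LinearBottomDims is seq<0, 1>, so per-access offsets
// along the last (fast-changing) dimension become compile-time immediates instead of vgprs.
auto a_win = make_tile_window_linear(a_view, a_lengths, {0, 0}, a_dstr);
auto a_reg = a_win.load();           // uses cached coords/flags plus immediate linear offsets
move_tile_window(a_win, {0, 32});    // only the cached non-linear coords are re-computed
auto a_reg_next = a_win.load();

// For an LDS view, sequence<1, 1> keeps a single vgpr offset and turns everything else
// into immediate offsets:
// auto lds_win = make_tile_window_linear(lds_view, lds_lengths, {0, 0}, lds_dstr, sequence<1, 1>{});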
include/ck_tile/core/utility/magic_div.hpp
...
...
@@ -58,10 +58,18 @@ struct magic_division32_bit_range
    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        if(__builtin_is_constant_evaluated())
        {
            uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
            return (tmp + dividend) >> shift;
        }
        else
        {
            uint32_t tmp = __umulhi(dividend, multiplier);
            return (tmp + dividend) >> shift;
        }
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
...
...
@@ -76,11 +84,20 @@ struct magic_division32_bit_range
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        if(__builtin_is_constant_evaluated())
        {
            uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
            uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
            return (tmp + dividend_u32) >> shift;
        }
        else
        {
            uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
            uint32_t tmp          = __umulhi(dividend_u32, multiplier);
            return (tmp + dividend_u32) >> shift;
        }
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
...
...
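The branches above compute floor(dividend / divisor) as ((umulhi(dividend, multiplier) + dividend) >> shift). A small, self-contained host sketch of that arithmetic follows; the multiplier/shift pair for divisor 3 is hand-derived for illustration and is not taken from this commit.

#include <cassert>
#include <cstdint>

// Same arithmetic as the constant-evaluated branch of do_magic_division():
// take the high 32 bits of dividend * multiplier, add the dividend, shift right.
static constexpr uint32_t magic_div(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
    uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
    return (tmp + dividend) >> shift;
}

int main()
{
    // Illustrative pair for divisor = 3: multiplier = ceil(2^34 / 3) - 2^32, shift = 2.
    constexpr uint32_t multiplier = 1431655766u;
    constexpr uint32_t shift      = 2u;
    for(uint32_t d = 0; d < 100000u; ++d)
        assert(magic_div(d, multiplier, shift) == d / 3u); // matches plain integer division
    return 0;
}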
include/ck_tile/host/host_tensor.hpp
...
...
@@ -11,6 +11,7 @@
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
...
...
@@ -532,6 +533,28 @@ struct HostTensor
    typename Data::size_type size() const { return mData.size(); }

    // return a slice of this tensor
    // for simplicity we just copy the data and return a new tensor
    auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
    {
        assert(s_begin.size() == s_end.size());
        assert(s_begin.size() == get_num_of_dimension());

        std::vector<size_t> s_len(s_begin.size());
        std::transform(s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});

        HostTensor<T> sliced_tensor(s_len);

        sliced_tensor.ForEach([&](auto& self, auto idx) {
            std::vector<size_t> src_idx(idx.size());
            std::transform(idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
            self(idx) = operator()(src_idx);
        });

        return sliced_tensor;
    }

    template <typename U = T>
    auto AsSpan() const
    {
...
...
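A short host-side sketch of the new slice() helper. It only uses calls visible in the hunk above (the lengths constructor, ForEach and operator() with an index vector); the concrete sizes are illustrative.

// Hedged sketch: copy-slice a 4x8 host tensor down to an inner 2x4 block.
std::vector<size_t> lens{4, 8};
ck_tile::HostTensor<float> t(lens);
t.ForEach([&](auto& self, auto idx) { self(idx) = static_cast<float>(idx[0] * 8 + idx[1]); });

// slice() takes per-dimension [begin, end) index vectors and copies the data out.
auto s = t.slice({1, 2}, {3, 6}); // resulting lengths are {3 - 1, 6 - 2} = {2, 4}
assert(s(std::vector<size_t>{0, 0}) == t(std::vector<size_t>{1, 2}));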
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
...
...
@@ -229,7 +229,7 @@ struct BlockFmhaPipelineQRAsyncEx
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
-       using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
+       using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
...
...
@@ -336,7 +336,7 @@ struct BlockFmhaPipelineQRAsyncEx
        do
        {
            // STAGE 1, QK gemm
-           clear_tile(s_acc); // initialize C
+           clear_tile(s_accs); // initialize C
            if constexpr(k0_loops > 1)
            {
                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
...
...
@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRAsyncEx
                    async_load_fence(k_dram_window.get_num_access());
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
-                   gemm_0(s_acc,
+                   gemm_0(s_accs,
                           get_slice_tile(q,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kM0, (i_k0 + 1) * kK0>{}),
...
...
@@ -373,7 +373,7 @@ struct BlockFmhaPipelineQRAsyncEx
            __builtin_amdgcn_sched_barrier(0);
            { // tail
-               gemm_0(s_acc,
+               gemm_0(s_accs,
                       get_slice_tile(q,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(k_lds_load,
...
...
@@ -385,8 +385,8 @@ struct BlockFmhaPipelineQRAsyncEx
            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
-               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
+               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
...
...
@@ -396,33 +396,33 @@ struct BlockFmhaPipelineQRAsyncEx
                            type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
-                   s_acc,
+                   s_accs,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin = k_dram_block_window.get_window_origin();
-               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
+               constexpr auto s_spans = decltype(s_accs)::get_distributed_spans();
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
-                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                           s_accs.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                       s_acc(i_j_idx) *= scale_s;
-                       position_encoding.update(s_acc(i_j_idx), row, col);
+                       s_accs(i_j_idx) *= scale_s;
+                       position_encoding.update(s_accs(i_j_idx), row, col);
                    });
                });
            }
            else
            {
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
-               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
+               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
#endif
            }
            move_tile_window(bias_dram_window, {0, kN0});
...
...
@@ -437,7 +437,7 @@ struct BlockFmhaPipelineQRAsyncEx
                if(need_perpixel_check)
                {
                    set_tile_if(
-                       s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
+                       s_accs, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);
...
...
@@ -445,7 +445,7 @@ struct BlockFmhaPipelineQRAsyncEx
                }
            }
-           const auto s = cast_tile<SMPLComputeDataType>(s_acc);  // S{j}
+           const auto s = cast_tile<SMPLComputeDataType>(s_accs); // S{j}
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                sequence<1>{},
...
...
include/ck_tile/ops/gemm/warp/warp_gemm.hpp
...
...
@@ -10,114 +10,134 @@
namespace ck_tile {

// fp16
-using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
+using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
+using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
+using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
+using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
+using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;
-using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
+using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
+using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
+using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
+using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
+using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
+using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

// bf16
-using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
+using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
+using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
+using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
+using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;
-using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
+using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

// fp8
-using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
+using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
+using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
+using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
+using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
+using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
+using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
+using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
-using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
+using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <index_t swizzle_factor = 2>
-using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>, 2, swizzle_factor>>;
+using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>, 2, swizzle_factor>>;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp
...
...
@@ -510,11 +510,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        });
    }

-   template <index_t kKIter, bool post_nop_ = false>
+   template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
-                                  number<kKIter>,
+                                  number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...
...
@@ -139,7 +139,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
        }
        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
        {
-           DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "b", "v")
+           DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
        }
        else
        {
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp
...
...
@@ -32,10 +32,8 @@ struct WarpGemmImpl
    using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;

    template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
    CK_TILE_DEVICE void
    operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
    {
        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
                      detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
...
...
@@ -56,7 +54,11 @@ struct WarpGemmImpl
        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }

    template <typename CTensor,
              typename ATensor,
              typename BTensor,
              index_t i_subk,
              bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CTensor& c,
                                   const ATensor& a,
                                   const BTensor& b,
...
...
@@ -82,6 +84,7 @@ struct WarpGemmImpl
    template <typename ATensor, typename BTensor>
    CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
    {
        using CTensor = CWarpTensor;
        static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
                      detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);
        CTensor c;
...
...
include/ck_tile/ops/reduce/block/block_reduce.hpp
...
...
@@ -160,10 +160,9 @@ CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor
    // reduction sweep forward
    static_for<0, nstage, 1>{}([&](auto istage) {
        // TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of xor
-       index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^ (number<1 << istage.value>{}.value);
+       index_t src_lane = __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);

        // pull data from remote lane
        const auto v_remote = warp_shuffle(v_local, src_lane);
...
...
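The changed line pairs each lane with a partner whose id differs in exactly one bit position (scaled by lid_over_rid_derivative), the classic XOR butterfly exchange. A standalone host-side sketch of the same pairing pattern follows; it is independent of the ck_tile types, and the lane count and stride here are illustrative only.

#include <array>
#include <cassert>

// Host simulation of the xor butterfly used above: at stage i every lane exchanges with
// (lane ^ (stride << i)) and accumulates, so after the final stage every lane in the
// group holds the fully reduced value.
int main()
{
    constexpr int lanes  = 64;
    constexpr int stride = 1; // plays the role of lid_over_rid_derivative
    std::array<int, lanes> v{};
    for(int l = 0; l < lanes; ++l)
        v[l] = l;

    for(int stage = 0; (stride << stage) < lanes; ++stage)
    {
        std::array<int, lanes> remote = v; // snapshot acts as the warp_shuffle source
        for(int l = 0; l < lanes; ++l)
            v[l] += remote[l ^ (stride << stage)]; // same src_lane formula as the new code
    }

    for(int l = 0; l < lanes; ++l)
        assert(v[l] == lanes * (lanes - 1) / 2); // every lane ends with the full sum
    return 0;
}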