gaoqiong / composable_kernel_ROCM / Commits / eed60199

Commit eed60199: "more robust api"
Authored Sep 13, 2024 by carlushuang
Parent: cae751d1
Changes: 27 total. Showing 20 changed files with 1727 additions and 210 deletions (+1727 -210).
example/ck_tile/05_moe/topk_softmax_api.cpp                               +25    -1
include/ck_tile/core.hpp                                                   +1    -0
include/ck_tile/core/algorithm/space_filling_curve.hpp                    +23    -2
include/ck_tile/core/arch/amd_buffer_addressing.hpp                       +54   -45
include/ck_tile/core/container/tuple.hpp                                  +24    -0
include/ck_tile/core/tensor/buffer_view.hpp                               +74   -53
include/ck_tile/core/tensor/load_tile.hpp                                 +84    -1
include/ck_tile/core/tensor/store_tile.hpp                                +27    -0
include/ck_tile/core/tensor/tensor_view.hpp                              +156   -10
include/ck_tile/core/tensor/tile_distribution.hpp                          +2    -0
include/ck_tile/core/tensor/tile_window.hpp                               +53   -11
include/ck_tile/core/tensor/tile_window_linear.hpp                      +1055    -0
include/ck_tile/core/utility/magic_div.hpp                                +22    -5
include/ck_tile/host/host_tensor.hpp                                      +23    -0
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp     +16   -16
include/ck_tile/ops/gemm/warp/warp_gemm.hpp                               +75   -55
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp                 +2    -2
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp            +1    -1
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp                           +8    -5
include/ck_tile/ops/reduce/block/block_reduce.hpp                          +2    -3
example/ck_tile/05_moe/topk_softmax_api.cpp

@@ -25,7 +25,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c

        using ts_input_type  = ck_tile::fp16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;
#if 1
        if(t.experts <= 8) { TOPK_SOFTMAX_DISPATCH(8) }

@@ -42,9 +42,24 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c

        { TOPK_SOFTMAX_DISPATCH(64) }
        else if(t.experts <= 128) { TOPK_SOFTMAX_DISPATCH(128) }
        else if(t.experts <= 192) { TOPK_SOFTMAX_DISPATCH(192) }
#else
        if(t.experts <= 16) { TOPK_SOFTMAX_DISPATCH(16) }
#endif
    }
    else if(t.input_type == "bf16" && t.weight_type == "fp32")
    {
#if 1
        using ts_input_type  = ck_tile::bf16_t;
        using ts_weight_type = float;
        using ts_index_type  = ck_tile::index_t;

@@ -64,6 +79,15 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c

        { TOPK_SOFTMAX_DISPATCH(64) }
        else if(t.experts <= 128) { TOPK_SOFTMAX_DISPATCH(128) }
        else if(t.experts <= 192) { TOPK_SOFTMAX_DISPATCH(192) }
#endif
    }
    return -1;
}
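TOPK_SOFTMAX_DISPATCH itself is not shown in these hunks. A minimal sketch of the dispatch pattern it follows, promoting a runtime expert count to the smallest supported compile-time bucket; the kernel/function names here are hypothetical stand-ins, not the real API:

// Sketch of the experts-count dispatch pattern, assuming a kernel templated
// on a compile-time expert count; names are hypothetical.
#include <cstdio>

template <int Experts>
float run_topk_softmax_kernel() // stand-in for the real kernel launch
{
    std::printf("dispatched with Experts=%d\n", Experts);
    return 0.0f; // the real API returns elapsed time
}

// Pick the smallest compile-time bucket >= the runtime value, mirroring the
// if/else-if chain in the diff above.
float dispatch_topk_softmax(int experts)
{
    if(experts <= 8)        return run_topk_softmax_kernel<8>();
    else if(experts <= 64)  return run_topk_softmax_kernel<64>();
    else if(experts <= 128) return run_topk_softmax_kernel<128>();
    else if(experts <= 192) return run_topk_softmax_kernel<192>();
    return -1.0f; // unsupported expert count
}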
include/ck_tile/core.hpp

@@ -50,6 +50,7 @@

 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
include/ck_tile/core/algorithm/space_filling_curve.hpp

@@ -66,6 +66,20 @@ struct space_filling_curve

        return idx_tail - idx_head;
    }

+   template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
+   static CK_TILE_HOST_DEVICE constexpr auto get_step_between_static(number<AccessIdx1dHead>,
+                                                                     number<AccessIdx1dTail>)
+   {
+       static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
+                     "1D index out of range");
+       static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
+                     "1D index out of range");
+
+       constexpr auto idx_head = get_index_static(number<AccessIdx1dHead>{});
+       constexpr auto idx_tail = get_index_static(number<AccessIdx1dTail>{});
+       return idx_tail - idx_head;
+   }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
    {

@@ -73,6 +87,13 @@ struct space_filling_curve

        return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
    }

+   template <index_t AccessIdx1d>
+   static CK_TILE_HOST_DEVICE constexpr auto get_forward_step_static(number<AccessIdx1d>)
+   {
+       static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
+       return get_step_between_static(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
+   }

    template <index_t AccessIdx1d>
    static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
    {

@@ -153,9 +174,9 @@ struct space_filling_curve

        return idx_md;
    }

    // FIXME: rename this function
    // FIXME: return tuple of number<>, which is compile time only variable
    template <index_t AccessIdx1d>
-   static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+   static CK_TILE_HOST_DEVICE constexpr auto get_index_static(number<AccessIdx1d>)
    {
        constexpr auto idx = get_index(number<AccessIdx1d>{});
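The point of the new *_static variants is that the step comes back as number<> constants rather than runtime values, so downstream code can fold it into immediates and static_asserts. A toy illustration of the distinction, using a simplified number<> wrapper rather than the actual ck_tile types:

// Toy model of compile-time index constants vs runtime values.
template <int V>
struct number { static constexpr int value = V; };

// runtime step: value only known at run time, lives in a register
inline int get_forward_step_runtime(int idx) { return (idx + 1) - idx; }

// compile-time step: the *type* encodes the value, so it can feed
// static_assert, template arguments, or instruction immediates
template <int Idx>
constexpr auto get_forward_step_static(number<Idx>)
{
    return number<(Idx + 1) - Idx>{}; // always number<1>, known to the compiler
}

static_assert(decltype(get_forward_step_static(number<3>{}))::value == 1);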
include/ck_tile/core/arch/amd_buffer_addressing.hpp

@@ -156,8 +156,8 @@ struct buffer_load<2, pre_nop>

                              index_t /*flag*/ = 0,
                              bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       // static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"

@@ -315,9 +315,9 @@ struct buffer_load_if<2, pre_nop>

                              index_t flag = 0,
                              bool_constant<pre_nop> = {})
    {
-       static_assert(sizeof(T) == 4);
+       // static_assert(sizeof(T) == 4);
        auto saved_exec = __builtin_amdgcn_read_exec();
-       using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
+       using mbuf_t = ushort; // typename impl::buffer_load_trait<2, T>::payload_t;
        if constexpr(pre_nop)
            asm volatile("s_nop 4\n"
                         "v_cmpx_le_u32 exec, 1, %4\n"

@@ -676,19 +676,17 @@ template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };

} // namespace impl

// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
template <index_t> struct smem_load;

template <>
struct smem_load<16>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)

@@ -700,9 +698,7 @@ template <>

struct smem_load<8>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;

@@ -717,9 +713,7 @@ template <>

struct smem_load<4>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;

@@ -734,11 +728,10 @@ template <>

struct smem_load<2>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u16 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)

@@ -750,9 +743,7 @@ template <>

struct smem_load<1>
{
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;

@@ -1879,6 +1870,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr

                                              int32x4_t dst_wave_buffer_resource,
                                              index_t dst_thread_addr_offset,
                                              index_t dst_wave_addr_offset,
+                                             index_t dst_linear_addr_offset,
                                              index_t is_valid_element = 1)
{
    constexpr index_t bytes = sizeof(T) * N;

@@ -1892,7 +1884,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr

            dst_wave_buffer_resource,
            dst_thread_addr_offset,
            dst_wave_addr_offset,
-           0,
+           dst_linear_addr_offset,
            is_valid_element);
    }
    else

@@ -1901,7 +1893,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr

            dst_wave_buffer_resource,
            dst_thread_addr_offset,
            dst_wave_addr_offset,
-           0);
+           dst_linear_addr_offset);
    }
}

@@ -2266,6 +2258,7 @@ template <typename T,

CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const T* p_src_wave,
                                                       index_t src_thread_element_offset,
+                                                      index_t src_linear_element_offset,
                                                       index_t src_element_space_size,
                                                       bool_constant<pre_nop> = {})
{

@@ -2273,9 +2266,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,

        make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                               src_wave_buffer_resource,
+                                               src_thread_addr_offset,
+                                               0,
+                                               src_linear_addr_offset,
+                                               bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg

@@ -2286,12 +2284,18 @@ template <typename T,

CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                       const int32x4_t src_wave_buffer_resource,
                                                       index_t src_thread_element_offset,
+                                                      index_t src_linear_element_offset,
                                                       bool_constant<pre_nop> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

-   amd_async_buffer_load_impl<T, N, coherence>(
-       smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+   amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                               src_wave_buffer_resource,
+                                               src_thread_addr_offset,
+                                               0,
+                                               src_linear_addr_offset,
+                                               bool_constant<pre_nop>{});
}

// This version support buffer resource as input arg

@@ -2302,16 +2306,18 @@ template <typename T,

CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
+                                                  index_t src_linear_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+   index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           src_thread_addr_offset,
                                           0,
+                                          src_linear_addr_offset,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}

@@ -2368,6 +2374,7 @@ template <typename T,

CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
+                                        const index_t dst_linear_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
{

@@ -2375,11 +2382,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d

        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
+   index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);

    amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
                                                                      dst_wave_buffer_resource,
                                                                      dst_thread_addr_offset,
                                                                      0,
+                                                                     dst_linear_addr_offset,
                                                                      dst_thread_element_valid);
}
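The recurring pattern in this file splits an address into a dynamic per-thread element offset plus a separate "linear" element offset, one that is known up front (for instance a step along a space-filling curve) and can later map onto an instruction immediate. A minimal host-side sketch of that split, independent of the GPU intrinsics:

// Sketch of the addressing split used above: the two element offsets are
// converted to byte offsets independently, so the linear part can become an
// immediate operand while the thread part stays in a register.
#include <cstddef>

template <typename T>
void compute_byte_offsets(std::size_t thread_element_offset,
                          std::size_t linear_element_offset,
                          std::size_t& thread_addr_offset,
                          std::size_t& linear_addr_offset)
{
    thread_addr_offset = thread_element_offset * sizeof(T); // dynamic part
    linear_addr_offset = linear_element_offset * sizeof(T); // immediate-able part
}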
include/ck_tile/core/container/tuple.hpp

@@ -635,6 +635,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)

    return r;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
+}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
                           bool> =

@@ -649,6 +657,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)

    return r;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
+}

template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value,
                           bool> =

@@ -686,6 +702,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)

    return a * x;
}

+template <typename... Xs, typename... Ys>
+CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
+{
+    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
+    constexpr index_t NSize = sizeof...(Xs);
+    return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
+}

template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
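A short sketch of what the new element-wise tuple operators enable; it assumes ck_tile's make_tuple and relies on argument-dependent lookup finding the operators above, with values chosen purely for illustration:

// Element-wise tuple arithmetic, matching the operator/, which already
// existed for tuple-tuple operands.
#include "ck_tile/core.hpp"

ck_tile::tuple<int, int, int> tuple_ops_demo()
{
    auto a = ck_tile::make_tuple(1, 2, 3);
    auto b = ck_tile::make_tuple(4, 5, 6);
    auto c = a + b; // {5, 7, 9}
    auto d = c * b; // {20, 35, 54}
    return d - a;   // {19, 33, 51}
}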
include/ck_tile/core/tensor/buffer_view.hpp
(diff collapsed in the original view; +74 -53)
include/ck_tile/core/tensor/load_tile.hpp

@@ -11,7 +11,7 @@

 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"
-#include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/tensor/null_tile_window.hpp"
 #include "ck_tile/core/tensor/null_tensor.hpp"

@@ -31,6 +31,20 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT

    return tile_window.load(bool_constant<oob_conditional_check>{});
}

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
+                                                       WindowLengths_,
+                                                       TileDistribution_,
+                                                       LinearBottomDims_>& tile_window,
+                              bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.load(bool_constant<oob_conditional_check>{});
+}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,

@@ -49,6 +63,24 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,

    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

+template <typename T,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE auto load_tile_raw(T& tile,
+                                  const tile_window_linear<BottomTensorView_,
+                                                           WindowLengths_,
+                                                           TileDistribution_,
+                                                           LinearBottomDims_>& tile_window,
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
+{
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+}

// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// while creating the smem window, which can enable compiler properly detect the
// dependency if using multiple smem window (multiple buffer)

@@ -69,6 +101,22 @@ async_load_tile(LdsTileWindow_&& lds_tile,

    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
}

+template <typename LdsTileWindow_,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true>
+CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
+                                    const tile_window_linear<BottomTensorView_,
+                                                             WindowLengths_,
+                                                             TileDistribution_,
+                                                             LinearBottomDims_>& tile_window,
+                                    bool_constant<oob_conditional_check> = {})
+{
+    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
+}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,

@@ -89,6 +137,25 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,

        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

+template <typename LdsTileWindow_,
+          typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
+CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
+                                        const tile_window_linear<BottomTensorView_,
+                                                                 WindowLengths_,
+                                                                 TileDistribution_,
+                                                                 LinearBottomDims_>& tile_window,
+                                        bool_constant<oob_conditional_check> = {},
+                                        bool_constant<pre_nop> = {})
+{
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+}

template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{

@@ -100,4 +167,20 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo

{
}

+// TODO: this function requires some sub-fields to exist for the target tile window
+template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
+CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
+{
+    using TileDstr = typename TileWindow::TileDstr;
+    using DataType = typename TileWindow::DataType;
+
+    auto t = make_static_distributed_tensor<DataType>(TileDstr{});
+    load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return t;
+}
} // namespace ck_tile
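These overloads make call sites independent of the concrete window type. A hedged sketch of such a call site (window construction is elided; store_tile's matching tile_window_linear overloads appear in the next file's diff):

// Works for both tile_window_with_static_distribution and the new
// tile_window_linear, because load_tile/store_tile are overloaded for both.
// src_window/dst_window construction is elided; illustrative only.
#include "ck_tile/core.hpp"

template <typename SrcWindow, typename DstWindow>
CK_TILE_DEVICE void copy_tile(const SrcWindow& src_window, DstWindow& dst_window)
{
    auto tile = ck_tile::load_tile(src_window); // distribution-aware vectorized load
    ck_tile::store_tile(dst_window, tile);      // matching vectorized store
}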
include/ck_tile/core/tensor/store_tile.hpp

@@ -10,6 +10,7 @@

 #include "ck_tile/core/container/container_helper.hpp"
 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"

 namespace ck_tile {

@@ -90,4 +91,30 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,

    tile_window.store_raw(dstr_tensor);
}

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          typename DataType_>
+CK_TILE_DEVICE void
+store_tile(tile_window_linear<BottomTensorView_,
+                              WindowLengths_,
+                              TileDistribution_,
+                              LinearBottomDims_>& tile_window,
+           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+{
+    tile_window.store(dstr_tensor);
+}

+template <typename BottomTensorView_,
+          typename WindowLengths_,
+          typename TileDistribution_,
+          typename LinearBottomDims_,
+          typename DataType_>
+CK_TILE_DEVICE void
+store_tile_raw(tile_window_linear<BottomTensorView_,
+                                  WindowLengths_,
+                                  TileDistribution_,
+                                  LinearBottomDims_>& tile_window,
+               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+{
+    tile_window.store_raw(dstr_tensor);
+}
} // namespace ck_tile
include/ck_tile/core/tensor/tensor_view.hpp

@@ -75,14 +75,34 @@ struct tensor_view

                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
+   get_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
+                           bool is_valid_element, // flag
+                           bool_constant<oob_conditional_check> = {}) const
+   {
+       return buf_.template get<X>(coord.get_offset(),
+                                   linear_offset,
+                                   is_valid_element,
+                                   bool_constant<oob_conditional_check>{});
+   }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,

@@ -106,6 +126,24 @@ struct tensor_view

                                      bool_constant<pre_nop>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             bool pre_nop = false,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE void
+   get_vectorized_elements_raw(remove_cvref_t<X>& dst,
+                               const TensorCoord& coord,
+                               index_t linear_offset,
+                               bool is_valid_element,
+                               bool_constant<oob_conditional_check> = {},
+                               bool_constant<pre_nop> = {}) const
+   {
+       return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
+           dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
+   }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<

@@ -114,26 +152,71 @@ struct tensor_view

                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
-                                 const TensorCoord& coord) const
+                                 const TensorCoord& coord,
+                                 index_t linear_offset) const
    {
        return buf_.template async_get<X>(
            smem,
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
+                                 const TensorCoord& coord,
+                                 index_t linear_offset,
+                                 bool is_valid_element) const
+   {
+       return buf_.template async_get<X>(smem,
+                                         coord.get_offset(),
+                                         linear_offset,
+                                         is_valid_element,
+                                         bool_constant<oob_conditional_check>{});
+   }

    template <typename X,
              bool pre_nop = false,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
                                      const TensorCoord& coord,
+                                     index_t linear_offset,
                                      bool_constant<pre_nop> = {}) const
    {
-       return buf_.template async_get_raw<X>(
-           smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
+       return buf_.template async_get_raw<X>(
+           smem,
+           coord.get_offset(),
+           linear_offset,
+           coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+           bool_constant<pre_nop>{});
    }

+   template <typename X,
+             bool pre_nop = false,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                     const TensorCoord& coord,
+                                     index_t linear_offset,
+                                     bool is_valid_element,
+                                     bool_constant<pre_nop> = {}) const
+   {
+       return buf_.template async_get_raw<X>(
+           smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
+   }

    // X is vector of DataType.

@@ -144,11 +227,15 @@ struct tensor_view

                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    set_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
                            const X& x,
                            bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

@@ -159,15 +246,53 @@ struct tensor_view

                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
-   CK_TILE_HOST_DEVICE constexpr void
-   set_vectorized_elements_raw(const TensorCoord& coord,
-                               const X& x,
-                               bool_constant<oob_conditional_check> = {})
+   CK_TILE_HOST_DEVICE constexpr void
+   set_vectorized_elements(const TensorCoord& coord,
+                           index_t linear_offset,
+                           bool is_valid_element,
+                           const X& x,
+                           bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(), linear_offset, is_valid_element, x);
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   set_vectorized_elements_raw(const TensorCoord& coord,
+                               index_t linear_offset,
+                               const X& x,
+                               bool_constant<oob_conditional_check> = {})
+   {
+       buf_.template set_raw<X, oob_conditional_check>(
+           coord.get_offset(),
+           linear_offset,
+           coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+           x);
+   }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   set_vectorized_elements_raw(const TensorCoord& coord,
+                               index_t linear_offset,
+                               bool is_valid_element,
+                               const X& x,
+                               bool_constant<oob_conditional_check> = {})
+   {
+       buf_.template set_raw<X, oob_conditional_check>(
+           coord.get_offset(), linear_offset, is_valid_element, x);
+   }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,

@@ -176,15 +301,36 @@ struct tensor_view

                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    update_vectorized_elements(const TensorCoord& coord,
+                              index_t linear_offset,
                               const X& x,
                               bool_constant<oob_conditional_check> = {})
    {
        buf_.template update<DstInMemOp, X, oob_conditional_check>(
            coord.get_offset(),
+           linear_offset,
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

+   template <typename X,
+             bool oob_conditional_check = true,
+             typename std::enable_if<
+                 std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                 bool>::type = false>
+   CK_TILE_HOST_DEVICE constexpr void
+   update_vectorized_elements(const TensorCoord& coord,
+                              index_t linear_offset,
+                              bool is_valid_element,
+                              const X& x,
+                              bool_constant<oob_conditional_check> = {})
+   {
+       buf_.template update<DstInMemOp, X, oob_conditional_check>(
+           coord.get_offset(), linear_offset, is_valid_element, x);
+   }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -454,6 +454,7 @@ struct tile_distribution_detail

} // namespace detail

+#if 0
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)

@@ -490,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution

        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
+#endif

// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
include/ck_tile/core/tensor/tile_window.hpp

@@ -223,10 +223,11 @@ struct tile_window_with_static_distribution

    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
+   template <typename ATopIndex>
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
-       const AdaptorTopIndex& idx_diff_adaptor_top) const
+       const ATopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

@@ -309,7 +310,7 @@ struct tile_window_with_static_distribution

                // read from bottom tensor
                const vector_t vec_value =
                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
-                       bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
+                       bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
                // write into distributed tensor
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {

@@ -337,10 +338,11 @@ struct tile_window_with_static_distribution

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
-                   constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+                   constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);

-                   constexpr auto idx_diff_ps_ys =
-                       container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                   constexpr auto idx_diff_ps_ys = container_concat(
+                       generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
+                       idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

@@ -398,7 +400,7 @@ struct tile_window_with_static_distribution

                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
-                   bottom_tensor_thread_coord, /**/
+                   bottom_tensor_thread_coord, 0, /**/
                    bool_constant<oob_conditional_check>{}, pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \

@@ -484,7 +486,7 @@ struct tile_window_with_static_distribution

                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
-                   smem, bottom_tensor_thread_coord, pre_nop_);
+                   smem, bottom_tensor_thread_coord, 0, pre_nop_);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -552,7 +554,7 @@ struct tile_window_with_static_distribution

                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                   smem, bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
+                   smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -618,7 +620,10 @@ struct tile_window_with_static_distribution

                // write into bottom tensor
                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
-                   bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+                   bottom_tensor_thread_coord,
+                   0,
+                   vec_value,
+                   bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -676,7 +681,7 @@ struct tile_window_with_static_distribution

                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
-                       bottom_tensor_thread_coord, vec_value);
+                       bottom_tensor_thread_coord, 0, vec_value);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -736,7 +741,10 @@ struct tile_window_with_static_distribution

                // write into bottom tensor
                get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
-                   bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+                   bottom_tensor_thread_coord,
+                   0,
+                   vec_value,
+                   bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -868,6 +876,27 @@ make_tile_window(const TensorView_& tensor_view,

        tensor_view, window_lengths, origin, tile_distribution};
}

+// this version must not be called under a constexpr context
+template <typename TensorView_,
+          typename WindowLengths_,
+          typename StaticTileDistribution_,
+          index_t NumCoord = 1>
+CK_TILE_DEVICE auto
+make_tile_window_raw(const TensorView_& tensor_view,
+                     const WindowLengths_& window_lengths,
+                     const multi_index<TensorView_::get_num_of_dimension()>& origin,
+                     const StaticTileDistribution_& tile_distribution,
+                     number<NumCoord> = {})
+{
+    auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
+                                                  remove_cvref_t<WindowLengths_>,
+                                                  remove_cvref_t<StaticTileDistribution_>,
+                                                  NumCoord>{
+        tensor_view, window_lengths, origin, tile_distribution};
+    w.init_raw();
+    return w;
+}

template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,

@@ -992,6 +1021,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths

        tile_distribution);
}

+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                     const StaticTileDistribution& tile_distribution)
+{
+    auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
+                              tile_window.get_window_lengths(),
+                              tile_window.get_window_origin(),
+                              tile_distribution);
+    w.init_raw();
+    return w;
+}

template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
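For reference, a hedged sketch of how the new raw factory might be called: same arguments as make_tile_window, with init_raw() run for the inline-asm load/store path. The braced origin and 2-D shape are assumptions for illustration only:

// Hypothetical call site for make_tile_window_raw; tensor-view and
// distribution setup are elided.
#include "ck_tile/core.hpp"

template <typename TensorView, typename Lengths, typename Dstr>
CK_TILE_DEVICE auto make_raw_window(const TensorView& view,
                                    const Lengths& lengths,
                                    const Dstr& dstr)
{
    // origin {0, 0} assumes a 2-D bottom tensor; illustrative only
    return ck_tile::make_tile_window_raw(view, lengths, {0, 0}, dstr);
}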
include/ck_tile/core/tensor/tile_window_linear.hpp (new file, mode 100644)
(diff collapsed in the original view; +1055 -0)
include/ck_tile/core/utility/magic_div.hpp

@@ -58,10 +58,18 @@ struct magic_division32_bit_range

    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
+       if(__builtin_is_constant_evaluated())
+       {
+           uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
+           return (tmp + dividend) >> shift;
+       }
+       else
+       {
            uint32_t tmp = __umulhi(dividend, multiplier);
            return (tmp + dividend) >> shift;
+       }
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)

@@ -76,11 +84,20 @@ struct magic_division32_bit_range

    // TODO: figure out how to do magic number divison for int32_t as dividended
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
+       if(__builtin_is_constant_evaluated())
+       {
+           uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
+           uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
+           return (tmp + dividend_u32) >> shift;
+       }
+       else
+       {
            uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
            uint32_t tmp          = __umulhi(dividend_u32, multiplier);
            return (tmp + dividend_u32) >> shift;
+       }
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
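Both branches compute the same thing: __umulhi(n, m) is exactly the high 32 bits of the 64-bit product, which the constant-evaluated path spells out so the function works in constexpr contexts. A host-side sanity check of a (multiplier, shift) pair, mirroring the formula above; the pair itself is assumed to come from the encoder elsewhere in this header:

// Brute-check a magic-division pair against plain division for small n.
#include <cstdint>
#include <cassert>

inline uint32_t magic_div(uint32_t n, uint32_t multiplier, uint32_t shift)
{
    // same as __umulhi(n, multiplier) on device
    uint32_t hi = static_cast<uint64_t>(n) * multiplier >> 32;
    return (hi + n) >> shift;
}

inline void check_magic_pair(uint32_t divisor, uint32_t multiplier, uint32_t shift)
{
    for(uint32_t n = 0; n < (1u << 20); ++n)
        assert(magic_div(n, multiplier, shift) == n / divisor);
}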
include/ck_tile/host/host_tensor.hpp

@@ -11,6 +11,7 @@

 #include <thread>
 #include <utility>
 #include <vector>
+#include <functional>

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/ranges.hpp"

@@ -532,6 +533,28 @@ struct HostTensor

    typename Data::size_type size() const { return mData.size(); }

+   // return a slice of this tensor
+   // for simplicity we just copy the data and return a new tensor
+   auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
+   {
+       assert(s_begin.size() == s_end.size());
+       assert(s_begin.size() == get_num_of_dimension());
+
+       std::vector<size_t> s_len(s_begin.size());
+       std::transform(
+           s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
+       HostTensor<T> sliced_tensor(s_len);
+       sliced_tensor.ForEach([&](auto& self, auto idx) {
+           std::vector<size_t> src_idx(idx.size());
+           std::transform(
+               idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
+           self(idx) = operator()(src_idx);
+       });
+
+       return sliced_tensor;
+   }

    template <typename U = T>
    auto AsSpan() const
    {
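A host-side sketch of the new slice in use, assuming HostTensor's usual shape-list constructor and the ForEach visitor visible in this file; the 4x4 shape and ranges are illustrative only:

// Copy out a 2x2 sub-block of a 4x4 host tensor.
#include "ck_tile/host/host_tensor.hpp"

void slice_demo()
{
    ck_tile::HostTensor<float> t({4, 4});
    t.ForEach([](auto& self, auto idx) { self(idx) = 1.0f; });

    // rows [1,3) x cols [0,2), returned as a fresh 2x2 tensor (data copied)
    auto s = t.slice({1, 0}, {3, 2});
}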
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp

@@ -229,7 +229,7 @@ struct BlockFmhaPipelineQRAsyncEx

        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
-       using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
+       using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));

@@ -336,7 +336,7 @@ struct BlockFmhaPipelineQRAsyncEx

        do
        {
            // STAGE 1, QK gemm
-           clear_tile(s_acc); // initialize C
+           clear_tile(s_accs); // initialize C
            if constexpr(k0_loops > 1)
            {
                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {

@@ -350,7 +350,7 @@ struct BlockFmhaPipelineQRAsyncEx

            async_load_fence(k_dram_window.get_num_access());
            __builtin_amdgcn_s_barrier();
            __builtin_amdgcn_sched_barrier(0);
-           gemm_0(s_acc,
+           gemm_0(s_accs,
                   get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),

@@ -373,7 +373,7 @@ struct BlockFmhaPipelineQRAsyncEx

            __builtin_amdgcn_sched_barrier(0);
            { // tail
-               gemm_0(s_acc,
+               gemm_0(s_accs,
                       get_slice_tile(
                           q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
                       get_slice_tile(k_lds_load,

@@ -385,8 +385,8 @@ struct BlockFmhaPipelineQRAsyncEx

            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
-               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
+               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2

@@ -396,33 +396,33 @@ struct BlockFmhaPipelineQRAsyncEx

                            type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
-                   s_acc,
+                   s_accs,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin    = k_dram_block_window.get_window_origin();
-               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
+               constexpr auto s_spans = decltype(s_accs)::get_distributed_spans();
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
-                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                           s_accs.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);

-                       s_acc(i_j_idx) *= scale_s;
-                       position_encoding.update(s_acc(i_j_idx), row, col);
+                       s_accs(i_j_idx) *= scale_s;
+                       position_encoding.update(s_accs(i_j_idx), row, col);
                    });
                });
            }
            else
            {
-               s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
+               s_accs = tile_elementwise_in(s_acc_element_func, s_accs);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
-               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
+               tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_accs);
#endif
            }
            move_tile_window(bias_dram_window, {0, kN0});

@@ -437,7 +437,7 @@ struct BlockFmhaPipelineQRAsyncEx

                if(need_perpixel_check)
                {
                    set_tile_if(
-                       s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
+                       s_accs, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);

@@ -445,7 +445,7 @@ struct BlockFmhaPipelineQRAsyncEx

                }
            }

-           const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
+           const auto s = cast_tile<SMPLComputeDataType>(s_accs); // S{j}

            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                sequence<1>{},
include/ck_tile/ops/gemm/warp/warp_gemm.hpp

@@ -10,114 +10,134 @@ Every alias now passes an explicit WGAttrCtlEnum::Default_ control argument to its WarpGemmAttributeMfmaImpl* type; the aliases after this change:

namespace ck_tile {

// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;

using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 1>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>, 2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>, 2>>;

// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfma<
    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp

@@ -510,11 +510,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB

        });
    }

-   template <index_t kKIter, bool post_nop_ = false>
+   template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
-                                  number<kKIter>,
+                                  number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp

@@ -139,7 +139,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16

        }
        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
        {
-           DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "b", "v")
+           DISPATCH_MFMA_("v_mfma_f32_16x16x16f16", "+v", "a", "a", "v")
        }
        else
        {
include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp

@@ -32,10 +32,8 @@ struct WarpGemmImpl

    using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;

    template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
-   CK_TILE_DEVICE void operator()(CTensor& c,
-                                  const ATensor& a,
-                                  const BTensor& b,
-                                  bool_constant<post_nop_> = {}) const
+   CK_TILE_DEVICE void
+   operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
    {
        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CTensor> &&
                      detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&

@@ -56,7 +54,11 @@ struct WarpGemmImpl

        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }

-   template <typename CTensor, typename ATensor, typename BTensor, index_t i_subk, bool post_nop_ = false>
+   template <typename CTensor,
+             typename ATensor,
+             typename BTensor,
+             index_t i_subk,
+             bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CTensor& c,
                                   const ATensor& a,
                                   const BTensor& b,

@@ -82,6 +84,7 @@ struct WarpGemmImpl

    template <typename ATensor, typename BTensor>
    CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
    {
        using CTensor = CWarpTensor;
+       static_assert(detail::is_similiar_distributed_tensor_v<ATensor, ATensor> &&
+                     detail::is_similiar_distributed_tensor_v<BTensor, BTensor>);

        CTensor c;
include/ck_tile/ops/reduce/block/block_reduce.hpp

@@ -160,10 +160,9 @@ CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor

    // reduction sweep forward
    static_for<0, nstage, 1>{}([&](auto istage) {
-       // TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
-       // xor
-       index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
-                          (number<1 << istage.value>{}.value);
+       index_t src_lane =
+           __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);

        // pull data from remote lane
        const auto v_remote = warp_shuffle(v_local, src_lane);
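The new src_lane formula is the standard XOR-butterfly schedule: at stage k, each lane exchanges with the lane whose index differs in bit position (stride << k). A host-side model of the same schedule over a plain array, for intuition (stride stands in for lid_over_rid_derivative; values are illustrative):

// XOR butterfly all-reduce: after log2(n/stride) stages, every lane in a
// group holds the same reduced value.
#include <vector>
#include <cstddef>

void xor_butterfly_reduce(std::vector<float>& lanes, std::size_t stride = 1)
{
    const std::size_t n = lanes.size(); // power of two, like a wavefront
    std::vector<float> next(n);
    for(std::size_t step = stride; step < n; step <<= 1) // step = stride << stage
    {
        for(std::size_t lane = 0; lane < n; ++lane)
            next[lane] = lanes[lane] + lanes[lane ^ step]; // warp_shuffle analogue
        lanes.swap(next);
    }
}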