gaoqiong / composable_kernel_ROCM · Commits

Commit fe488bf2 (unverified)
Authored Oct 22, 2024 by rocking; committed by GitHub, Oct 22, 2024

    Merge branch 'develop' into layernorm/instance_support

Parents: 8bed8529, 0394f8a7

Changes: 80 files in the full commit; this view shows 20 changed files with 1041 additions and 511 deletions (+1041, -511).
example/ck_tile/05_reduce/reduce.hpp  (+118, -0)
example/ck_tile/CMakeLists.txt  (+1, -0)
include/ck_tile/core.hpp  (+1, -0)
include/ck_tile/core/arch/utility.hpp  (+43, -0)
include/ck_tile/core/config.hpp  (+2, -0)
include/ck_tile/core/container/sequence.hpp  (+122, -0)
include/ck_tile/core/container/tuple.hpp  (+20, -0)
include/ck_tile/core/tensor/static_distributed_tensor.hpp  (+14, -0)
include/ck_tile/core/tensor/sweep_tile.hpp  (+278, -0)
include/ck_tile/core/tensor/tile_distribution.hpp  (+35, -123)
include/ck_tile/core/utility/functional_with_tuple.hpp  (+173, -0)
include/ck_tile/host.hpp  (+1, -1)
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp  (+0, -0)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp  (+18, -9)
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp  (+46, -21)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  (+2, -1)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp  (+13, -2)
include/ck_tile/ops/gemm.hpp  (+1, -1)
include/ck_tile/ops/layernorm2d.hpp  (+5, -2)
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp  (+148, -351)
example/ck_tile/05_reduce/reduce.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

template <typename ADataType,
          typename AccDataType,
          typename BDataType,
          index_t kBlockSize,
          typename BlockWarps,  // num warps along seq<M, N>
          typename BlockTile,   // block size, seq<M, N>
          typename WarpTile,    // warp size, seq<M, N>
          typename ThreadTile>  // contiguous pixels(vector size) along seq<M, N>
struct Reduce
{
    static constexpr index_t Block_M  = BlockTile::at(number<0>{});
    static constexpr index_t Block_N  = BlockTile::at(number<1>{});
    static constexpr index_t Warp_M   = WarpTile::at(number<0>{});
    static constexpr index_t Warp_N   = WarpTile::at(number<1>{});
    static constexpr index_t Thread_M = ThreadTile::at(number<0>{});
    static constexpr index_t Thread_N = ThreadTile::at(number<1>{});

    static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

    static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N;

    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    __device__ static constexpr auto MakeABlockTileDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<Repeat_M, WarpPerBlock_M, ThreadPerWarp_M, Thread_M>,
                      sequence<Repeat_N, WarpPerBlock_N, ThreadPerWarp_N, Thread_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<1, 1, 2, 2>,
                sequence<0, 3, 0, 3>>{});
    }

    __device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const
    {
        const auto a_m_n = make_naive_tensor_view<address_space_enum::global>(
            p_a, make_tuple(M, N), make_tuple(N, 1), number<Thread_N>{}, number<1>{});

        const auto iM = get_block_id() * Block_M;

        // A window
        auto a_block_window = make_tile_window(a_m_n,
                                               make_tuple(number<Block_M>{}, number<Block_N>{}),
                                               {iM, 0},
                                               MakeABlockTileDistribution());

        const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; };

        const ADataType reduce_init_value = 0;

        constexpr auto reduce_dims = sequence<1>{};

        // Acc tile
        // TODO: support cross warp reduction
        auto acc_block_tensor = decltype(block_tile_reduce<AccDataType>(
            load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){};

        // init Acc tile
        tile_elementwise_inout(
            [&](auto& acc) { acc = type_convert<AccDataType>(reduce_init_value); },
            acc_block_tensor);

        // loop
        index_t iN = 0;
        do
        {
            const auto a_block_tensor = load_tile(a_block_window);

            // FIXME: support cross warp reduction
            block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce);

            move_tile_window(a_block_window, {0, Block_N});
            iN += Block_N;
        } while(iN < N);

        // FIXME: support cross warp reduction
        block_tile_reduce_sync(acc_block_tensor, f_reduce);

        // convert acc_block_tensor to b_block_tensor
        const auto b_block_tensor = tile_elementwise_in(
            [](const auto& acc) { return type_convert<BDataType>(acc); }, acc_block_tensor);

        // B
        const auto b_m = make_naive_tensor_view_packed<address_space_enum::global>(
            p_b, make_tuple(M), number<32>{});

        // B window
        auto b_block_window = make_tile_window(b_m, make_tuple(number<Block_M>{}), {iM});

        // store B tile
        store_tile(b_block_window, b_block_tensor);
    }
};

} // namespace ck_tile
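For orientation, a minimal host-side sketch of how this Reduce functor might be instantiated. The tile shapes below are illustrative assumptions chosen to satisfy the divisibility rules above on a wave64 GPU; they are not values taken from this commit.

// Hypothetical instantiation (not part of the diff): 4x1 warps, 32x64 block tile,
// 8x64 warp tile, 1x8 pixels per thread -> 256 threads, Repeat_M = Repeat_N = 1.
using ReduceRowSum = ck_tile::Reduce<float, float, float, 256,
                                     ck_tile::sequence<4, 1>,   // BlockWarps
                                     ck_tile::sequence<32, 64>, // BlockTile
                                     ck_tile::sequence<8, 64>,  // WarpTile
                                     ck_tile::sequence<1, 8>>;  // ThreadTile

__global__ void reduce_row_sum(const float* p_a, float* p_b,
                               ck_tile::index_t M, ck_tile::index_t N)
{
    // launch with M / 32 blocks of 256 threads; each block sums Block_M rows over all N columns
    ReduceRowSum{}(p_a, p_b, M, N);
}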
example/ck_tile/CMakeLists.txt

@@ -6,3 +6,4 @@ add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
 add_subdirectory(04_img2col)
+add_subdirectory(05_reduce)
include/ck_tile/core.hpp

@@ -52,6 +52,7 @@
 #include "ck_tile/core/tensor/update_tile.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/functional_with_tuple.hpp"
 #include "ck_tile/core/utility/ignore.hpp"
 #include "ck_tile/core/utility/magic_div.hpp"
 #include "ck_tile/core/utility/philox_rand.hpp"
include/ck_tile/core/arch/utility.hpp

@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
(all lines below are added, just before the closing of namespace ck_tile)

template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
    return __shfl(v_local, src_lane);
#elif 1
    if constexpr(sizeof(int32_t) > sizeof(T))
    {
        union packet
        {
            int32_t x;
            T v;
        };
        packet p;
        p.v = v_local;
        packet p_remote;
        p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));

        return p_remote.v;
    }
    else if constexpr(sizeof(int32_t) == sizeof(T))
    {
        const int32_t v_remote_tmp =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
        return bit_cast<T>(v_remote_tmp);
    }
    else
    {
        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
        constexpr index_t elm = sizeof(T) / sizeof(int32_t);
        using vector_type     = thread_buffer<int32_t, elm>;
        auto vs        = bit_cast<vector_type>(v_local);
        auto vs_remote = vector_type{};
        static_for<0, elm, 1>{}([&](auto i_e) {
            int32_t tmp =
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
            vs_remote(i_e) = tmp;
        });
        return bit_cast<T>(vs_remote);
    }
#endif
}

} // namespace ck_tile
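As a quick orientation for the ds_bpermute paths above, a hedged device-side sketch. The kernel itself is illustrative and not part of the diff; it only assumes the warp_shuffle signature added above.

// Assumed usage (illustrative): broadcast lane 0's value across each wavefront.
__global__ void broadcast_lane0(const float* in, float* out)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    // sizeof(float) == sizeof(int32_t), so this takes the single-bpermute path
    out[tid] = ck_tile::warp_shuffle(in[tid], 0u);
}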
include/ck_tile/core/config.hpp

@@ -32,11 +32,13 @@
 #define CK_TILE_DEVICE inline __device__
 #define CK_TILE_HOST_DEVICE inline __host__ __device__
 #define CK_TILE_DEVICE_EXTERN __device__
+#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
 #else
 #define CK_TILE_HOST inline
 #define CK_TILE_DEVICE inline
 #define CK_TILE_HOST_DEVICE inline
 #define CK_TILE_DEVICE_EXTERN
+#define CK_TILE_HOST_DEVICE_EXTERN
 #endif

 #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
include/ck_tile/core/container/sequence.hpp

@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
                      typename arithmetic_sequence_gen<0, N, 1>::type{});
 }
(all lines below are added, up to the closing of namespace ck_tile)

namespace impl {

template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t _split_flag =
        std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};

template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};

} // namespace impl

// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8     -> lengths:<1, 1, 4, 2>    , nums: <2, 1, 1, 1>    : 2 slices  , slice_idx: 0
//      <4, 2, 4, 1, 2>, 4  -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
//      <4, 2, 4, 1, 6>, 4  -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
//      <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices  , slice_idx: 1
//
//      <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1  slices , slice_idx: 0
//      <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2  slices , slice_idx: 0
//      <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4  slices , slice_idx: 0
//      <4, 2, 8>, 8  -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8  slices , slice_idx: 1
//      <4, 2, 8>, 4  -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
//      <4, 2, 8>, 2  -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
//      <4, 2, 8>, 1  -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
//      <4, 2, 1, 4, 2> / 4 ->
//      mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        impl::reverse_slice_sequence_impl<Seq,
                                          Mask,
                                          typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                          SliceSize>;
    static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}

template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
                              number<SliceSize>,
                              Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    constexpr auto r =
        reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
    return make_tuple(r[number<0>{}].reverse(),
                      r[number<1>{}].reverse(),
                      number<Seq::size() - r[number<2>{}] - 1>{});
}

} // namespace ck_tile
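A hedged compile-time probe of one row of the table above. The static_asserts are illustrative assumptions (they mirror what the table documents) and are not part of the diff.

// Assumed check: <4, 2, 8> sliced by 8, consumed right -> left.
using namespace ck_tile;
constexpr auto r = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
static_assert(std::is_same_v<remove_cvref_t<decltype(r[number<0>{}])>,
                             sequence<1, 1, 8>>); // per-slice lengths
static_assert(std::is_same_v<remove_cvref_t<decltype(r[number<1>{}])>,
                             sequence<4, 2, 1>>); // slice counts: 8 slices total
static_assert(r[number<2>{}].value == 1);         // first split index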
include/ck_tile/core/container/tuple.hpp

@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
 }
(all lines below are added, before unroll_nested_tuple)

namespace detail {

template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return concat_tuple(f(x.at(number<Is>{}))...);
}

} // namespace detail

// make sure F returns at least a tuple
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
// this function will return tuple<Z, W, Z, W> (f applied to each element, concatenated)
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
    return detail::embed_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}

(unchanged context follows:)

// By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
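Completing the example in the comment above with a hedged snippet; the concrete values are assumptions used only for illustration.

// Assumed illustration (not from the diff): f returns a tuple per element and
// embed_tuples concatenates the per-element results left to right.
using namespace ck_tile;
constexpr auto x = make_tuple(number<1>{}, number<2>{});
constexpr auto y = embed_tuples([](auto e) { return make_tuple(e, e); }, x);
// y is tuple<number<1>, number<1>, number<2>, number<2>> -- "tuple<Z, W, Z, W>"
// in the notation of the comment above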
include/ck_tile/core/tensor/static_distributed_tensor.hpp

@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
     });
 }
(all lines below are added, before the closing of namespace ck_tile)

// this function is used inside the span loop-over helpers
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
{
    constexpr auto y_size  = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
    constexpr auto y_packs = number<XUnpacks>{};
    static_assert(y_size % y_packs == 0);
    constexpr auto y_slice_size = y_size / y_packs;
    constexpr auto slice_info   = slice_sequence(YLengths{}, number<y_slice_size>{});
    constexpr auto unpacks      = slice_info[number<1>{}];
    return unpacks;
}

} // namespace ck_tile
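Since the mapping from X-unpacks to Y-unpacks goes through slice_sequence, a hedged worked trace under assumed inputs:

// Assumed trace (not from the diff): y lengths <2, 4>, 2 unpacks requested.
//   y_size = 2 * 4 = 8, y_slice_size = 8 / 2 = 4
//   slice_sequence(sequence<2, 4>{}, number<4>{}) -> per-slice lengths <2, 2>,
//   slice counts <1, 2>
// so get_y_unpacks_from_x_unpacks(sequence<2, 4>{}, number<2>{}) yields
// sequence<1, 2>{}: the unpack factor of 2 is carried by the inner y dimension.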
include/ck_tile/core/tensor/sweep_tile.hpp

@@ -8,6 +8,7 @@
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/tensor/tile_distribution.hpp"
 #include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/functional_with_tuple.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"

 namespace ck_tile {

@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
     });
 }
(all lines below are added, up to the closing of namespace ck_tile)

// unpacked span: this version supports a span with an unpack (multi-arg) functor
template <typename TileDistributedSpan_, // tile_distributed_span<...>
          typename F,                    // signature: F(tile_distributed_index<...>)
          typename Unpacks =
              typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
{
    using DstrSpan = remove_cvref_t<TileDistributedSpan_>;

    static_uford<typename DstrSpan::Impl, Unpacks>{}(
        [&](auto... dstr_idx_impl) { f(detail::make_tile_distributed_index(dstr_idx_impl)...); });
}

namespace impl {

template <typename, typename, typename>
struct sweep_tile_impl;

template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
    CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
    {
        constexpr auto spans     = DistributedTensor::get_distributed_spans();
        constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
        constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
        constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
        return y_unpacks;
    }

    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        constexpr auto u =
            static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
        return u.get_num_of_access() *
               sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
                   .get_num_of_access();
    }

    template <typename F, typename SpanIdx>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        sweep_tile_uspan(
            spans[number<I>{}],
            [&](auto... i_idx) {
                const auto next_span_idx = embed_tuples(
                    [&](auto si) { return make_tuple(concat_tuple(si, make_tuple(i_idx))...); },
                    span_idx);
                sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
                    f, next_span_idx);
            },
            get_y_unpacks());
    }

    template <typename F, typename SpanIdx, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void
    operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        constexpr auto u =
            static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
        constexpr auto access_stride =
            sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
                .get_num_of_access();
        constexpr auto curr_i_access = number<i_access / access_stride>{};
        constexpr auto next_i_access = number<i_access % access_stride>{};
        u(
            [&](auto... i_idx) {
                const auto next_span_idx = embed_tuples(
                    [&](auto si) {
                        return make_tuple(concat_tuple(
                            si, make_tuple(detail::make_tile_distributed_index(i_idx)))...);
                    },
                    span_idx);
                sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
                    f, next_span_idx, next_i_access);
            },
            curr_i_access);
    }
};

template <typename DistributedTensor, typename UnpacksPerXDim>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
{
    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }

    template <typename F, typename SpanIdx>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
    {
        unpack(f, span_idx);
    }

    template <typename F, typename SpanIdx, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void
    operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
    {
        unpack(f, span_idx);
    }
};

template <typename, typename, typename>
struct sweep_tile_impl_0;

// TODO: support empty tuple to remove this "entry-point" like function
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
    CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
    {
        constexpr auto spans     = DistributedTensor::get_distributed_spans();
        constexpr auto y_lengths = typename decltype(spans[number<I>{}])::Impl{};
        constexpr auto x_unpacks = number<UnpacksPerXDim{}.at(number<I>{})>{};
        constexpr auto y_unpacks = get_y_unpacks_from_x_unpacks(y_lengths, x_unpacks);
        return y_unpacks;
    }

    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        constexpr auto u =
            static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
        return u.get_num_of_access() *
               sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
                   .get_num_of_access();
    }

    template <typename F>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        sweep_tile_uspan(
            spans[number<I>{}],
            [&](auto... i_idx) {
                constexpr auto next_span_idx = make_tuple(make_tuple(i_idx)...);
                sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
                    f, next_span_idx);
            },
            get_y_unpacks());
    }

    template <typename F, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        constexpr auto u =
            static_uford<typename decltype(spans[number<I>{}])::Impl, decltype(get_y_unpacks())>{};
        constexpr auto access_stride =
            sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}
                .get_num_of_access();
        constexpr auto curr_i_access = number<i_access / access_stride>{};
        constexpr auto next_i_access = number<i_access % access_stride>{};
        u(
            [&](auto... i_idx) {
                constexpr auto next_span_idx =
                    make_tuple(make_tuple(detail::make_tile_distributed_index(i_idx))...);
                sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>{}(
                    f, next_span_idx, next_i_access);
            },
            curr_i_access);
    }
};

} // namespace impl

/*
 * Enhanced sweep-tile utility that can control unpacks along each X-dim.
 * The lambda argument is the distributed-idx, which can be plugged directly
 * into the distributed tensor as setter/getter.
 *
 * e.g. in the functions below, y has the type DistributedTensor, r is a row scale
 *
 *  // sweep tile 1 by 1
 *  sweep_tile<DistributedTensor>([&](auto idx) {
 *      constexpr auto row_id = make_tuple(idx[number<0>{}]);
 *      y(idx) = y(idx) * r(row_id);
 *  });
 *
 *  // sweep tile with 2 pixels from the last dim per function call
 *  sweep_tile<DistributedTensor>(
 *      [&](auto idx_0, auto idx_1) {
 *          constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
 *          y(idx_0) = y(idx_0) * r(row_id);
 *          y(idx_1) = y(idx_1) * r(row_id);
 *      },
 *      sequence<1, 2>{});
 *
 *  // sweep tile with a 2x2 pixel group per function call
 *  sweep_tile<DistributedTensor>(
 *      [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
 *          constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
 *          constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
 *          y(idx_00) = y(idx_00) * r(row_id0);
 *          y(idx_01) = y(idx_01) * r(row_id0);
 *          y(idx_10) = y(idx_10) * r(row_id1);
 *          y(idx_11) = y(idx_11) * r(row_id1);
 *      },
 *      sequence<2, 2>{});
 *
 * TODO: do we need constexpr? lambda function could be non-constexpr
 */
template <typename DistributedTensor,
          typename F,
          typename UnpacksPerXDim =
              typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
{
    constexpr auto spans = DistributedTensor::get_distributed_spans();
    impl::sweep_tile_impl_0<DistributedTensor,
                            UnpacksPerXDim,
                            typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(f);
}

template <typename DistributedTensor,
          typename F,
          typename UnpacksPerXDim =
              typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
{
    sweep_tile<DistributedTensor, F, UnpacksPerXDim>(f, UnpacksPerXDim{});
}

/*
 * Construct a sweep-tile instance, which supports issuing the lambda one access at a time.
 * Note that this struct will hold the lambda functor, but will not hold the distributed
 * tensor; the functionality is otherwise the same as sweep_tile().
 */
template <typename DistributedTensor_,
          typename F_,
          typename UnpacksPerXDim_ =
              typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
struct tile_sweeper
{
    using DistributedTensor = remove_cvref_t<DistributedTensor_>;
    using F                 = remove_cvref_t<F_>;
    using UnpacksPerXDim    = remove_cvref_t<UnpacksPerXDim_>;

    CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}
    CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
        : f(f_)
    {
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        constexpr auto tmp =
            impl::sweep_tile_impl_0<DistributedTensor,
                                    UnpacksPerXDim,
                                    typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{};
        return tmp.get_num_of_access();
    }

    CK_TILE_HOST_DEVICE void operator()() const
    {
        sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
    }

    template <index_t i_access>
    CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        impl::sweep_tile_impl_0<DistributedTensor,
                                UnpacksPerXDim,
                                typename arithmetic_sequence_gen<0, spans.size(), 1>::type>{}(
            f, number<i_access>{});
    }

    F f;
};

// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;

// deduction guide
template <typename T,
          typename F,
          typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper<T, F, U>;

} // namespace ck_tile
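A hedged sketch of the one-access-at-a-time mode that tile_sweeper adds over sweep_tile; the tensor type, functor, and helper name are placeholders, not code from this commit.

// Assumed usage sketch (device code, not from the diff): issue the sweep lambda
// access-by-access, e.g. to interleave it with other instructions.
template <typename DistTensor>
CK_TILE_DEVICE void double_tile(DistTensor& y)
{
    auto ts = ck_tile::tile_sweeper{y, [&](auto idx) { y(idx) = y(idx) * 2; }};
    // issuing every access is equivalent to a single sweep_tile(y, f) call
    ck_tile::static_for<0, decltype(ts)::get_num_of_access(), 1>{}(
        [&](auto i_access) { ts(i_access); });
}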
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -17,6 +17,14 @@
 namespace ck_tile {

+namespace detail {
+
+template <typename Distribution>
+CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
+{
+    return Distribution::_get_partition_index();
+}
+
+} // namespace detail
+
 // distributed span
 template <index_t... PartialHsLengths>
 struct tile_distributed_span

@@ -83,6 +91,21 @@ struct tile_distribution
     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }

+    CK_TILE_HOST_DEVICE static auto _get_partition_index()
+    {
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        if constexpr(NDimP == 1) { return array<index_t, 1>{get_lane_id()}; }
+        else if constexpr(NDimP == 2) { return array<index_t, 2>{get_warp_id(), get_lane_id()}; }
+    }
+
     CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
     {
 #if 0

@@ -149,6 +172,16 @@ struct tile_distribution
     }
 #endif

+    template <typename PartitionIndex = decltype(_get_partition_index())>
+    CK_TILE_HOST_DEVICE auto
+    calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
+    {
+        const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
+        const auto window_adaptor_thread_coord_tmp =
+            make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
+        return window_adaptor_thread_coord_tmp.get_bottom_index();
+    }
+
     CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
     {
         constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;

@@ -421,6 +454,7 @@ struct tile_distribution_detail
 } // namespace detail

+#if 0
 // this returns a constexpr tile_distribution
 template <typename StaticTileDistributionEncoding_>
 CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)

@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
         detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
         ps_ys_to_xs_adaptor, ys_to_d_descriptor};
 }
+#endif

 // this returns a static tile_distribution
 template <typename StaticTileDistributionEncoding_>

@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
 //***********************************************************************************

 namespace detail {

-template <typename Distribution>
-CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
-{
-    // only support warp-tile and block-tile
-    static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
-
-    if constexpr(Distribution::NDimP == 1) { return array<index_t, 1>{get_lane_id()}; }
-    else if constexpr(Distribution::NDimP == 2)
-    {
-        return array<index_t, 2>{get_warp_id(), get_lane_id()};
-    }
-}
-
-template <typename, typename, typename, index_t>
-struct reverse_slice_sequence_impl;
...
(the remainder of this removed block -- the reverse_slice_sequence_impl specializations,
the slice-table comment, and reverse_slice_sequence -- is the same code that this commit
adds to include/ck_tile/core/container/sequence.hpp, shown above)

 //
 // slice tensor from x_dim, result in split in y_dim, not p_dim.
 // We don't support slice cross p_dim (aka, slice different threads)
include/ck_tile/core/utility/functional_with_tuple.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>

namespace ck_tile {

namespace detail {

// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_uford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
        static_assert(RamainUnpacks::size() > 0, "wrong! should not get here");
    }

    template <class F, class CurrentUnpackIds>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds) const
    {
        constexpr index_t pack_len = RamainUnpacks::front();
        static_for<0, RemainLengths::front(), pack_len>{}([=](auto I) {
            constexpr auto new_pack = generate_tuple(
                [&](auto idx_) {
                    constexpr auto i_new_pack = number<I + idx_ % pack_len>{};
                    constexpr auto i_pre_pack = number<idx_ / pack_len>{};
                    return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
                },
                number<CurrentUnpackIds::size() * pack_len>{});
            static_uford_impl<decltype(RemainLengths::pop_front()),
                              decltype(RamainUnpacks::pop_front()),
                              Orders>{}(f, new_pack);
        });
    }
};

template <class Orders>
struct static_uford_impl<sequence<>, sequence<>, Orders>
{
    template <class F, class PackedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId) const
    {
        constexpr auto origin_packs = transform_tuples(
            [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
        unpack(f, origin_packs);
    }
};

template <class RemainLengths, class RamainUnpacks, class Orders>
struct static_uford_one_shot_impl
{
    template <class F, class CurrentUnpackIds, index_t current_acc>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
    {
        constexpr auto r_lens_stride =
            reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
        constexpr auto r_upks_stride =
            reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
        constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
        constexpr index_t pack_len       = RamainUnpacks::front();
        constexpr index_t current_idx    = (current_acc / current_stride) * pack_len;

        constexpr auto new_pack = generate_tuple(
            [&](auto idx_) {
                constexpr auto i_new_pack = number<current_idx + idx_ % pack_len>{};
                constexpr auto i_pre_pack = number<idx_ / pack_len>{};
                return CurrentUnpackIds{}.at(i_pre_pack).push_back(i_new_pack);
            },
            number<CurrentUnpackIds::size() * pack_len>{});

        static_uford_one_shot_impl<decltype(RemainLengths::pop_front()),
                                   decltype(RamainUnpacks::pop_front()),
                                   Orders>{}(f, new_pack, number<current_acc % current_stride>{});
    }
};

template <class Orders>
struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
{
    template <class F, class PackedId, index_t current_acc>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, PackedId, number<current_acc>) const
    {
        constexpr auto origin_packs = transform_tuples(
            [](auto pack_) { return decltype(pack_)::reorder_old_to_new(Orders{}); }, PackedId{});
        unpack(f, origin_packs);
    }
};

} // namespace detail

// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// the functor passed in must take the same number of arguments as there are packs
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){});  // requires 2 args (packs)
//
//  loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
//  loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
//  loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
//  loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
//  loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
//  loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
//  loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
//  ...
template <class Lengths,
          class Unpacks = typename uniform_sequence_gen<Lengths::size(), 1>::type,
          class Orders  = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
    static constexpr index_t num_packs =
        reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});

    CK_TILE_HOST_DEVICE constexpr static_uford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Unpacks::size(), "wrong! inconsistent size");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
        static_for<0, Lengths::size(), 1>{}(
            [&](auto i) { static_assert(Lengths{}.at(i) % Unpacks{}.at(i) == 0); });
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        using L_ = decltype(Lengths{} / Unpacks{});
        return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
    }

    // F signature: F(sequence<...> multi_id...)
    //   multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
        detail::static_uford_impl<decltype(ordered_lengths), decltype(ordered_unpacks), Orders>{}(
            f, make_tuple(sequence<>{}));
    }

    // this version is friendly for issuing the function one access at a time
    template <class F, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, number<i_access>) const
    {
        static_assert(i_access < get_num_of_access());
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        constexpr auto ordered_unpacks = Unpacks::reorder_new_to_old(Orders{});
        detail::static_uford_one_shot_impl<decltype(ordered_lengths),
                                           decltype(ordered_unpacks),
                                           Orders>{}(
            f, make_tuple(sequence<>{}), number<i_access>{});
    }
};

} // namespace ck_tile
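To make the loop trace in the comment concrete, a hedged device-side snippet; the function name is a placeholder and the body is illustrative only.

// Assumed illustration (not from the diff): visit a 2x3x4 index space two
// inner elements per call, matching the trace in the comment above.
CK_TILE_DEVICE void visit_in_pairs()
{
    using namespace ck_tile;
    static_uford<sequence<2, 3, 4>, sequence<1, 1, 2>>{}([&](auto i_0, auto i_1) {
        // first call: i_0 == sequence<0, 0, 0>{}, i_1 == sequence<0, 0, 1>{}
        ignore = i_0;
        ignore = i_1;
    });
    // total calls: get_num_of_access() == (2/1) * (3/1) * (4/2) == 12
}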
include/ck_tile/host.hpp

@@ -21,7 +21,7 @@
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
-#include "ck_tile/host/reference/reference_layernorm2d.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
include/ck_tile/host/reference/reference_layernorm2d.hpp
  → include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp

File moved without content changes.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp

@@ -12,6 +12,16 @@ namespace detail {
 template <index_t N>
 struct log2;

+template <>
+struct log2<4> : std::integral_constant<index_t, 2>
+{
+};
+
+template <>
+struct log2<8> : std::integral_constant<index_t, 3>
+{
+};
+
 template <>
 struct log2<16> : std::integral_constant<index_t, 4>
 {

@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
     {
         if constexpr(kHeadDimV <= 32)
         {
-            constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
-            return occupancy[detail::log2<kMaxSplits>::value - 4];
+            constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
+            return occupancy[detail::log2<kMaxSplits>::value - 2];
        }
        else if constexpr(kHeadDimV <= 128)
        {
-            constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
-            return occupancy[detail::log2<kMaxSplits>::value - 4];
+            constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
+            return occupancy[detail::log2<kMaxSplits>::value - 2];
        }
        else if constexpr(kHeadDimV <= 256)
        {
-            constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
-            return occupancy[detail::log2<kMaxSplits>::value - 4];
+            constexpr std::array occupancy{2, 2, 2, 2, 2, 1};
+            return occupancy[detail::log2<kMaxSplits>::value - 2];
        }
    }
    }();

@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
     auto lse_accum = make_static_distributed_tensor<LSEDataType>(
         Policy::template MakeLSEaccRegTileDistribution<Problem>());

-    // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
-    // this will extend the distributed tensor width so that each thread in wave have data to
-    // reduce.
+    // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
+    // and fill up -INF values outside the [kM0, num_splits] region.
     {
         constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
         sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
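A brief hedged check of the re-indexed occupancy tables; it assumes the log2<128> specialization already present in this header, and is not code from the commit.

// Assumed check: with log2<4> and log2<8> added, the index
// log2(kMaxSplits) - 2 spans the new 6-entry tables, e.g. for kHeadDimV <= 32:
//   kMaxSplits =   4 -> occupancy[2 - 2] = occupancy[0] = 3
//   kMaxSplits = 128 -> occupancy[7 - 2] = occupancy[5] = 1
// The old 4-entry tables (indexed by log2(kMaxSplits) - 4) only covered
// kMaxSplits in 16..128.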
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
fe488bf2
...
@@ -10,11 +10,26 @@ namespace ck_tile {
...
@@ -10,11 +10,26 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
return
NPerThread
;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
return
GetVectorSizeForTile
<
Problem
::
kBlockSize
,
return
16
/
sizeof
(
LSEDataType
);
Problem
::
kMaxSplits
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
}
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPerThread
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
TotalWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
kNumWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
TotalWarps
*
MThreadsPerWarp
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
Total
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
MPerThread
*
kNum
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
Total
Warps
,
MThreadsPerWarp
>
,
tuple
<
sequence
<
MPerThread
,
kNum
Warps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
...
@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence
<
0
,
1
>>
{});
sequence
<
0
,
1
>>
{});
}
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding,
shape=
[kMaxSplits, kM0]
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_lds_block_desc
;
return
lse_acc_lds_block_desc
;
}
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding,
shape=
[kM0, kMaxSplits]
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
{
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
...
@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
-   constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
+   constexpr index_t kNPerBlock = Problem::kMaxSplits;
    constexpr index_t kMPerBlock = Problem::kM0;

-   constexpr index_t NThreads   = get_warp_size();
+   constexpr index_t NThreads   = 4;
    constexpr index_t NPerThread = kNPerBlock / NThreads;

    constexpr index_t MThreads   = kBlockSize / NThreads;
    constexpr index_t MPerThread = kMPerBlock / MThreads;

+   constexpr index_t MWarps         = kBlockSize / get_warp_size();
+   constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;

    static_assert(NThreads * NPerThread == kNPerBlock);
-   static_assert(MThreads * MPerThread == kMPerBlock);
+   static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);

    return make_static_tile_distribution(
        tile_distribution_encoding<
            sequence<1>,
-           tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
+           tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
-           tuple<sequence<1>, sequence<2>>,
+           tuple<sequence<1>, sequence<2, 1>>,
-           tuple<sequence<0>, sequence<0>>,
+           tuple<sequence<0>, sequence<0, 1>>,
            sequence<1, 2>,
-           sequence<1, 1>>{});
+           sequence<2, 1>>{});
}

template <typename Problem>
...
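The distribution change replaces a flat MThreads x NThreads lane grid with an explicitly warp-aware MWarps x MThreadPerWarp x NThreads factorization, and drops the max(kMaxSplits, get_warp_size()) clamp so kMaxSplits no longer has to reach a full wavefront. A worked example with assumed sizes (kBlockSize = 256, a 64-lane wavefront, kMaxSplits = 8, kM0 = 64 are illustrative values, not taken from the diff):

// How the old and new factorizations relate, with assumed sizes.
constexpr int kBlockSize = 256;
constexpr int warp_size  = 64; // get_warp_size() on CDNA

constexpr int NThreads   = 4;                     // lanes along the split axis
constexpr int kNPerBlock = 8;                     // kMaxSplits
constexpr int NPerThread = kNPerBlock / NThreads; // 2 splits per lane

constexpr int MThreads   = kBlockSize / NThreads; // 64 rows covered at once
constexpr int kMPerBlock = 64;                    // kM0
constexpr int MPerThread = kMPerBlock / MThreads; // 1 row per lane

constexpr int MWarps         = kBlockSize / warp_size; // 4 warps stacked on M
constexpr int MThreadPerWarp = warp_size / NThreads;   // 16 rows inside a warp

// The two factorizations cover the same row count...
static_assert(MWarps * MThreadPerWarp == MThreads);
// ...but the new encoding makes the warp boundary explicit, so the lane-to-row
// mapping no longer assumes kMaxSplits >= warp_size (the old max(...) clamp).
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);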
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...
@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
    using ODataType = remove_cvref_t<ODataType_>;
    using Traits    = remove_cvref_t<Traits_>;

-   static constexpr index_t kBlockSize = 256;
+   static constexpr index_t kNumWarps  = kM0_ / (get_warp_size() / 4);
+   static constexpr index_t kBlockSize = kNumWarps * get_warp_size();

    static constexpr bool kIsGroupMode  = kIsGroupMode_;
    static constexpr index_t kHeadDimV  = HeadDimV_;
...
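Instead of a fixed 256-thread block, the combine pipeline now sizes the block from the kM0 tile: each warp covers get_warp_size() / 4 rows, which lines up with the NThreads = 4 lanes-per-row used by the LSE tile distribution above (that connection is an inference from this diff, not stated in it). A quick check with assumed values:

// How the derived block size behaves, assuming get_warp_size() = 64.
// The kM0_ values below are illustrative.
constexpr int warp_size     = 64;
constexpr int rows_per_warp = warp_size / 4; // 16 rows, 4 lanes per row

// kM0_ = 64: kNumWarps = 64 / 16 = 4 -> kBlockSize = 256 (the old constant)
static_assert((64 / rows_per_warp) * warp_size == 256);
// kM0_ = 32: kNumWarps = 2 -> kBlockSize = 128, launching half the threads
static_assert((32 / rows_per_warp) * warp_size == 128);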
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...
@@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
                                      typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

    constexpr auto warp_gemm = []() {
+       constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+       static_assert(WarpGemmM == 16 || WarpGemmM == 32);

        if constexpr (std::is_same_v<typename Problem::QDataType, half_t> &&
                      std::is_same_v<typename Problem::KDataType, half_t> &&
                      std::is_same_v<typename Problem::SaccDataType, float>)
        {
+           if constexpr (WarpGemmM == 32)
                return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+           else // WarpGemmM == 16
+               return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
        }
        else if constexpr (std::is_same_v<typename Problem::QDataType, bf16_t> &&
                           std::is_same_v<typename Problem::KDataType, bf16_t> &&
                           std::is_same_v<typename Problem::SaccDataType, float>)
        {
+           if constexpr (WarpGemmM == 32)
                return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+           else // WarpGemmM == 16
+               return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
        }
        else if constexpr (std::is_same_v<typename Problem::QDataType, fp8_t> &&
                           std::is_same_v<typename Problem::KDataType, fp8_t> &&
                           std::is_same_v<typename Problem::SaccDataType, float>)
        {
+           static_assert(WarpGemmM == 32);
            // TODO: hard-coded here; otherwise it may produce incorrect results
            constexpr index_t swizzle_factor = 4;
            return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
...
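The policy previously hard-wired the 32x32x16 MFMA warp tiles; it now reads WarpGemmM from Gemm0WarpTile and selects a 16x16x16 variant when WarpGemmM == 16. The mechanism is an immediately invoked constexpr lambda whose if-constexpr branches return different types. A self-contained sketch of that pattern, using stand-in types rather than the real ck_tile warp-GEMM classes:

// Stand-in sketch of compile-time warp-GEMM selection. The WarpGemmM32 /
// WarpGemmM16 types are placeholders, not the actual ck_tile distributions.
#include <type_traits>

struct WarpGemmM32 { static constexpr int kM = 32; };
struct WarpGemmM16 { static constexpr int kM = 16; };

template <int WarpGemmM>
constexpr auto select_warp_gemm()
{
    static_assert(WarpGemmM == 16 || WarpGemmM == 32);
    if constexpr (WarpGemmM == 32)
        return WarpGemmM32{};
    else
        return WarpGemmM16{};
}

// Each branch yields a distinct concrete type; per-instantiation return-type
// deduction is what makes the if-constexpr dispatch work.
static_assert(std::is_same_v<decltype(select_warp_gemm<32>()), WarpGemmM32>);
static_assert(std::is_same_v<decltype(select_warp_gemm<16>()), WarpGemmM16>);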
include/ck_tile/ops/gemm.hpp
...
@@ -23,12 +23,12 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
...
include/ck_tile/ops/layernorm2d.hpp
...
@@ -4,6 +4,9 @@
#pragma once

#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
-#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
-#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
+#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
This diff is collapsed.