Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5cfd751b
Commit
5cfd751b
authored
Oct 20, 2024
by
carlushuang
Browse files
refactor layernorm2d pipeline and add block-per-block utility
parent
68e67701
Changes
40
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1475 additions
and
597 deletions
+1475
-597
example/ck_tile/02_layernorm2d/script/smoke_test.sh
example/ck_tile/02_layernorm2d/script/smoke_test.sh
+22
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+2
-0
include/ck_tile/core/container/sequence.hpp
include/ck_tile/core/container/sequence.hpp
+122
-0
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+20
-0
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+14
-0
include/ck_tile/core/tensor/sweep_tile.hpp
include/ck_tile/core/tensor/sweep_tile.hpp
+278
-0
include/ck_tile/core/utility/functional_with_tuple.hpp
include/ck_tile/core/utility/functional_with_tuple.hpp
+173
-0
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+4
-2
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+148
-354
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
.../ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
+79
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
.../pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
+99
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_pipeline.hpp
...norm2d/pipeline/layernorm2d_fwd_warp_per_row_pipeline.hpp
+120
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp
...rnorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp
+4
-1
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
...e/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
+0
-36
include/ck_tile/ops/welford.hpp
include/ck_tile/ops/welford.hpp
+2
-1
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+345
-0
include/ck_tile/ops/welford/block/block_welford_problem.hpp
include/ck_tile/ops/welford/block/block_welford_problem.hpp
+18
-0
include/ck_tile/ops/welford/thread/thread_welford.hpp
include/ck_tile/ops/welford/thread/thread_welford.hpp
+24
-89
include/ck_tile/ops/welford/warp/warp_welford.hpp
include/ck_tile/ops/welford/warp/warp_welford.hpp
+0
-114
No files found.
example/ck_tile/02_layernorm2d/script/smoke_test.sh
0 → 100644
View file @
5cfd751b
#!/bin/sh
# call from top of CK folder
EXE
=
./build/bin/tile_example_layernorm2d_fwd
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-m
=
99
-n
=
13
$EXE
-prec
=
$pr_i
-m
=
17
-n
=
16
$EXE
-prec
=
$pr_i
-m
=
1
-n
=
100
$EXE
-prec
=
$pr_i
-m
=
4
-n
=
128
$EXE
-prec
=
$pr_i
-m
=
80
-n
=
127
$EXE
-prec
=
$pr_i
-m
=
22
-n
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-m
=
7
-n
=
599
$EXE
-prec
=
$pr_i
-m
=
19
-n
=
512
$EXE
-prec
=
$pr_i
-m
=
33
-n
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-m
=
11
-n
=
510
$EXE
-prec
=
$pr_i
-m
=
171
-n
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-m
=
91
-n
=
636
$EXE
-prec
=
$pr_i
-m
=
12
-n
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-m
=
100
-n
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-m
=
31
-n
=
1024
$EXE
-prec
=
$pr_i
-m
=
64
-n
=
1000
done
include/ck_tile/core.hpp
View file @
5cfd751b
...
@@ -52,6 +52,7 @@
...
@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
...
...
include/ck_tile/core/config.hpp
View file @
5cfd751b
...
@@ -32,11 +32,13 @@
...
@@ -32,11 +32,13 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#else
#define CK_TILE_HOST inline
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
...
...
include/ck_tile/core/container/sequence.hpp
View file @
5cfd751b
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
}
}
namespace
impl
{
template
<
typename
,
typename
,
typename
,
index_t
>
struct
reverse_slice_sequence_impl
;
template
<
index_t
x
,
index_t
...
xs
,
index_t
m
,
index_t
...
ms
,
index_t
id
,
index_t
...
ids
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
,
xs
...
>
,
sequence
<
m
,
ms
...
>
,
sequence
<
id
,
ids
...
>
,
SliceSize
>
{
using
old_scan
=
reverse_slice_sequence_impl
<
sequence
<
xs
...
>
,
sequence
<
ms
...
>
,
sequence
<
ids
...
>
,
SliceSize
>
;
static
constexpr
auto
slice_size
=
old_scan
::
remaining_slice_sizes
::
front
().
value
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
typename
sequence_merge
<
sequence
<
slice_length
>
,
typename
old_scan
::
dim_lengths
>::
type
;
using
dim_slices
=
typename
sequence_merge
<
sequence
<
x
/
slice_length
>
,
typename
old_scan
::
dim_slices
>::
type
;
using
remaining_slice_sizes
=
typename
sequence_merge
<
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
,
typename
old_scan
::
remaining_slice_sizes
>::
type
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
_split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
_split_idx
=
std
::
conditional_t
<
_split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_flag
=
_split_flag
||
old_scan
::
split_flag
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
old_scan
::
split_flag
,
number
<
old_scan
::
split_idx
>
,
number
<
_split_idx
>>::
value
;
};
template
<
index_t
x
,
index_t
m
,
index_t
id
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
>
,
sequence
<
m
>
,
sequence
<
id
>
,
SliceSize
>
{
static
constexpr
auto
slice_size
=
SliceSize
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
sequence
<
slice_length
>
;
using
dim_slices
=
sequence
<
x
/
slice_length
>
;
using
remaining_slice_sizes
=
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
};
}
// namespace impl
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template
<
typename
Seq
,
index_t
SliceSize
,
typename
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>
::
type
>
constexpr
auto
reverse_slice_sequence
(
Seq
,
number
<
SliceSize
>
,
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>::
type
{})
{
static_assert
(
Seq
::
size
()
==
Mask
::
size
());
using
sliced_type
=
impl
::
reverse_slice_sequence_impl
<
Seq
,
Mask
,
typename
arithmetic_sequence_gen
<
0
,
Seq
::
size
(),
1
>::
type
,
SliceSize
>
;
static_assert
(
sliced_type
::
remaining_slice_sizes
::
front
().
value
==
1
,
"can not evenly divide this sequence, please check"
);
return
make_tuple
(
typename
sliced_type
::
dim_lengths
{},
typename
sliced_type
::
dim_slices
{},
number
<
sliced_type
::
split_idx
>
{});
}
template
<
typename
Seq
,
index_t
SliceSize
,
typename
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>
::
type
>
constexpr
auto
slice_sequence
(
Seq
,
number
<
SliceSize
>
,
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>::
type
{})
{
constexpr
auto
r
=
reverse_slice_sequence
(
Seq
{}.
reverse
(),
number
<
SliceSize
>
{},
Mask
{}.
reverse
());
return
make_tuple
(
r
[
number
<
0
>
{}].
reverse
(),
r
[
number
<
1
>
{}].
reverse
(),
number
<
Seq
::
size
()
-
r
[
number
<
2
>
{}]
-
1
>
{});
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/container/tuple.hpp
View file @
5cfd751b
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
}
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
embed_tuples_impl
(
F
f
,
const
X
&
x
,
sequence
<
Is
...
>
)
{
return
concat_tuple
(
f
(
x
.
at
(
number
<
Is
>
{}))...);
}
}
// namespace detail
// make sure F return at least a tuple
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
// this function will return
template
<
typename
F
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
embed_tuples
(
F
f
,
const
X
&
x
)
{
return
detail
::
embed_tuples_impl
(
f
,
x
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
}
// By default unroll to the flatten
// By default unroll to the flatten
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
5cfd751b
...
@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
...
@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
});
}
}
// this function used inside span loop over
template
<
typename
YLengths
,
index_t
XUnpacks
>
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks_from_x_unpacks
(
YLengths
,
number
<
XUnpacks
>
)
{
constexpr
auto
y_size
=
reduce_on_sequence
(
YLengths
{},
multiplies
{},
number
<
1
>
{});
constexpr
auto
y_packs
=
number
<
XUnpacks
>
{};
static_assert
(
y_size
%
y_packs
==
0
);
constexpr
auto
y_slice_size
=
y_size
/
y_packs
;
constexpr
auto
slice_info
=
slice_sequence
(
YLengths
{},
number
<
y_slice_size
>
{});
constexpr
auto
unpacks
=
slice_info
[
number
<
1
>
{}];
return
unpacks
;
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/sweep_tile.hpp
View file @
5cfd751b
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
...
@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
});
});
}
}
// unpacked span, this version support span with unpack(multi-arg) functor
//
template
<
typename
TileDistributedSpan_
,
// tile_distributed_span<...>
typename
F
,
// signature: F(tile_distributed_index<...>)
typename
Unpacks
=
typename
uniform_sequence_gen
<
TileDistributedSpan_
::
Impl
::
size
(),
1
>
::
type
>
CK_TILE_DEVICE
void
sweep_tile_uspan
(
TileDistributedSpan_
,
const
F
&
f
,
Unpacks
=
{})
{
using
DstrSpan
=
remove_cvref_t
<
TileDistributedSpan_
>
;
static_uford
<
typename
DstrSpan
::
Impl
,
Unpacks
>
{}(
[
&
](
auto
...
dstr_idx_impl
)
{
f
(
detail
::
make_tile_distributed_index
(
dstr_idx_impl
)...);
});
}
namespace
impl
{
template
<
typename
,
typename
,
typename
>
struct
sweep_tile_impl
;
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
,
index_t
I
,
index_t
...
Is
>
struct
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
I
,
Is
...
>>
{
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
y_lengths
=
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
{};
constexpr
auto
x_unpacks
=
number
<
UnpacksPerXDim
{}.
at
(
number
<
I
>
{})
>
{};
constexpr
auto
y_unpacks
=
get_y_unpacks_from_x_unpacks
(
y_lengths
,
x_unpacks
);
return
y_unpacks
;
}
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
return
u
.
get_num_of_access
()
*
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
}
template
<
typename
F
,
typename
SpanIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
sweep_tile_uspan
(
spans
[
number
<
I
>
{}],
[
&
](
auto
...
i_idx
)
{
const
auto
next_span_idx
=
embed_tuples
(
[
&
](
auto
si
)
{
return
make_tuple
(
concat_tuple
(
si
,
make_tuple
(
i_idx
))...);
},
span_idx
);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
);
},
get_y_unpacks
());
}
template
<
typename
F
,
typename
SpanIdx
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
,
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
constexpr
auto
access_stride
=
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
constexpr
auto
curr_i_access
=
number
<
i_access
/
access_stride
>
{};
constexpr
auto
next_i_access
=
number
<
i_access
%
access_stride
>
{};
u
(
[
&
](
auto
...
i_idx
)
{
const
auto
next_span_idx
=
embed_tuples
(
[
&
](
auto
si
)
{
return
make_tuple
(
concat_tuple
(
si
,
make_tuple
(
detail
::
make_tile_distributed_index
(
i_idx
)))...);
},
span_idx
);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
,
next_i_access
);
},
curr_i_access
);
}
};
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
>
struct
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<>>
{
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
return
1
;
}
template
<
typename
F
,
typename
SpanIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
)
const
{
unpack
(
f
,
span_idx
);
}
template
<
typename
F
,
typename
SpanIdx
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
,
number
<
i_access
>
)
const
{
unpack
(
f
,
span_idx
);
}
};
template
<
typename
,
typename
,
typename
>
struct
sweep_tile_impl_0
;
// TODO: support empty tuple to remove this "entry-point" like function
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
,
index_t
I
,
index_t
...
Is
>
struct
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
I
,
Is
...
>>
{
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
y_lengths
=
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
{};
constexpr
auto
x_unpacks
=
number
<
UnpacksPerXDim
{}.
at
(
number
<
I
>
{})
>
{};
constexpr
auto
y_unpacks
=
get_y_unpacks_from_x_unpacks
(
y_lengths
,
x_unpacks
);
return
y_unpacks
;
}
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
return
u
.
get_num_of_access
()
*
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
}
template
<
typename
F
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
sweep_tile_uspan
(
spans
[
number
<
I
>
{}],
[
&
](
auto
...
i_idx
)
{
constexpr
auto
next_span_idx
=
make_tuple
(
make_tuple
(
i_idx
)...);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
);
},
get_y_unpacks
());
}
template
<
typename
F
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
constexpr
auto
access_stride
=
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
constexpr
auto
curr_i_access
=
number
<
i_access
/
access_stride
>
{};
constexpr
auto
next_i_access
=
number
<
i_access
%
access_stride
>
{};
u
(
[
&
](
auto
...
i_idx
)
{
constexpr
auto
next_span_idx
=
make_tuple
(
make_tuple
(
detail
::
make_tile_distributed_index
(
i_idx
))...);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
,
next_i_access
);
},
curr_i_access
);
}
};
}
// namespace impl
/*
* Enhanced sweep-tile utility, can control unpacks along each X-dim
* the lambda function argument is the distributed-idx, which can directly
* plugged into the distributed tensor as setter/getter
*
* e.g. below function, y with the type DistributedTensor, r is row scale
*
* // sweep tile 1 by 1
* sweep_tile<DistributedTensor>([&](auto idx) {
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
* y(idx) = y(idx) * r(row_id);
* });
*
* // sweep tile with 2 pixel from last dim each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_0, auto idx_1) {
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
* y(idx_0) = y(idx_0) * r(row_id);
* y(idx_1) = y(idx_1) * r(row_id);
* },
* sequence<1, 2>{});
*
* // sweep tile with 2x2 pixel each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
* y(idx_00) = y(idx_00) * r(row_id0);
* y(idx_01) = y(idx_01) * r(row_id0);
* y(idx_10) = y(idx_10) * r(row_id1);
* y(idx_11) = y(idx_11) * r(row_id1);
* },
* sequence<2, 2>{});
*
* TODO: do we need constexpr? lambda function could be non-constexpr
*/
template
<
typename
DistributedTensor
,
typename
F
,
typename
UnpacksPerXDim
=
typename
uniform_sequence_gen
<
DistributedTensor
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE
constexpr
void
sweep_tile
(
const
F
&
f
,
UnpacksPerXDim
=
{})
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{}(
f
);
}
template
<
typename
DistributedTensor
,
typename
F
,
typename
UnpacksPerXDim
=
typename
uniform_sequence_gen
<
DistributedTensor
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE
constexpr
void
sweep_tile
(
const
DistributedTensor
&
,
const
F
&
f
,
UnpacksPerXDim
=
{})
{
sweep_tile
<
DistributedTensor
,
F
,
UnpacksPerXDim
>
(
f
,
UnpacksPerXDim
{});
}
/*
* construct a sweep tile instance, which support issue the lambda one by one
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
* the functionality is the same as sweep_tile()
*/
template
<
typename
DistributedTensor_
,
typename
F_
,
typename
UnpacksPerXDim_
=
typename
uniform_sequence_gen
<
DistributedTensor_
::
get_num_of_dimension
(),
1
>
::
type
>
struct
tile_sweeper
{
using
DistributedTensor
=
remove_cvref_t
<
DistributedTensor_
>
;
using
F
=
remove_cvref_t
<
F_
>
;
using
UnpacksPerXDim
=
remove_cvref_t
<
UnpacksPerXDim_
>
;
CK_TILE_HOST_DEVICE
tile_sweeper
(
const
F
&
f_
,
UnpacksPerXDim
=
{})
:
f
(
f_
)
{}
CK_TILE_HOST_DEVICE
tile_sweeper
(
const
DistributedTensor
&
,
const
F
&
f_
,
UnpacksPerXDim
=
{})
:
f
(
f_
)
{
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_access
()
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
tmp
=
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{};
return
tmp
.
get_num_of_access
();
}
CK_TILE_HOST_DEVICE
void
operator
()()
const
{
sweep_tile
<
DistributedTensor
>
(
f
,
UnpacksPerXDim
{});
}
template
<
index_t
i_access
>
CK_TILE_HOST_DEVICE
void
operator
()(
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{}(
f
,
number
<
i_access
>
{});
}
F
f
;
};
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// deduction guide
template
<
typename
T
,
typename
F
,
typename
U
=
typename
uniform_sequence_gen
<
T
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE_EXTERN
tile_sweeper
(
const
T
&
,
const
F
&
,
U
=
{})
->
tile_sweeper
<
T
,
F
,
U
>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/utility/functional_with_tuple.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace
ck_tile
{
namespace
detail
{
// RemainLengths: sequence<...>
// Orders: sequence<...>
template
<
class
RemainLengths
,
class
RamainUnpacks
,
class
Orders
>
struct
static_uford_impl
{
CK_TILE_HOST_DEVICE
constexpr
static_uford_impl
()
{
static_assert
(
RemainLengths
::
size
()
>
0
,
"wrong! should not get here"
);
static_assert
(
RamainUnpacks
::
size
()
>
0
,
"wrong! should not get here"
);
}
template
<
class
F
,
class
CurrentUnpackIds
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
CurrentUnpackIds
)
const
{
constexpr
index_t
pack_len
=
RamainUnpacks
::
front
();
static_for
<
0
,
RemainLengths
::
front
(),
pack_len
>
{}([
=
](
auto
I
)
{
constexpr
auto
new_pack
=
generate_tuple
(
[
&
](
auto
idx_
)
{
constexpr
auto
i_new_pack
=
number
<
I
+
idx_
%
pack_len
>
{};
constexpr
auto
i_pre_pack
=
number
<
idx_
/
pack_len
>
{};
return
CurrentUnpackIds
{}.
at
(
i_pre_pack
).
push_back
(
i_new_pack
);
},
number
<
CurrentUnpackIds
::
size
()
*
pack_len
>
{});
static_uford_impl
<
decltype
(
RemainLengths
::
pop_front
()),
decltype
(
RamainUnpacks
::
pop_front
()),
Orders
>
{}(
f
,
new_pack
);
});
}
};
template
<
class
Orders
>
struct
static_uford_impl
<
sequence
<>
,
sequence
<>
,
Orders
>
{
template
<
class
F
,
class
PackedId
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
PackedId
)
const
{
constexpr
auto
origin_packs
=
transform_tuples
(
[](
auto
pack_
)
{
return
decltype
(
pack_
)
::
reorder_old_to_new
(
Orders
{});
},
PackedId
{});
unpack
(
f
,
origin_packs
);
}
};
template
<
class
RemainLengths
,
class
RamainUnpacks
,
class
Orders
>
struct
static_uford_one_shot_impl
{
template
<
class
F
,
class
CurrentUnpackIds
,
index_t
current_acc
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
CurrentUnpackIds
,
number
<
current_acc
>
)
const
{
constexpr
auto
r_lens_stride
=
reverse_exclusive_scan_sequence
(
RemainLengths
{},
multiplies
{},
number
<
1
>
{});
constexpr
auto
r_upks_stride
=
reverse_exclusive_scan_sequence
(
RamainUnpacks
{},
multiplies
{},
number
<
1
>
{});
constexpr
index_t
current_stride
=
r_lens_stride
.
front
()
/
r_upks_stride
.
front
();
constexpr
index_t
pack_len
=
RamainUnpacks
::
front
();
constexpr
index_t
current_idx
=
(
current_acc
/
current_stride
)
*
pack_len
;
constexpr
auto
new_pack
=
generate_tuple
(
[
&
](
auto
idx_
)
{
constexpr
auto
i_new_pack
=
number
<
current_idx
+
idx_
%
pack_len
>
{};
constexpr
auto
i_pre_pack
=
number
<
idx_
/
pack_len
>
{};
return
CurrentUnpackIds
{}.
at
(
i_pre_pack
).
push_back
(
i_new_pack
);
},
number
<
CurrentUnpackIds
::
size
()
*
pack_len
>
{});
static_uford_one_shot_impl
<
decltype
(
RemainLengths
::
pop_front
()),
decltype
(
RamainUnpacks
::
pop_front
()),
Orders
>
{}(
f
,
new_pack
,
number
<
current_acc
%
current_stride
>
{});
}
};
template
<
class
Orders
>
struct
static_uford_one_shot_impl
<
sequence
<>
,
sequence
<>
,
Orders
>
{
template
<
class
F
,
class
PackedId
,
index_t
current_acc
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
PackedId
,
number
<
current_acc
>
)
const
{
constexpr
auto
origin_packs
=
transform_tuples
(
[](
auto
pack_
)
{
return
decltype
(
pack_
)
::
reorder_old_to_new
(
Orders
{});
},
PackedId
{});
unpack
(
f
,
origin_packs
);
}
};
}
// namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// you must make sure the function passed in has same number of argument
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template
<
class
Lengths
,
class
Unpacks
=
typename
uniform_sequence_gen
<
Lengths
::
size
(),
1
>
::
type
,
class
Orders
=
typename
arithmetic_sequence_gen
<
0
,
Lengths
::
size
(),
1
>::
type
>
struct
static_uford
{
static
constexpr
index_t
num_packs
=
reduce_on_sequence
(
Unpacks
{},
multiplies
{},
number
<
1
>
{});
CK_TILE_HOST_DEVICE
constexpr
static_uford
()
{
static_assert
(
Lengths
::
size
()
>
0
,
"wrong! Lengths is empty"
);
static_assert
(
Lengths
::
size
()
==
Unpacks
::
size
(),
"wrong! inconsistent size"
);
static_assert
(
Lengths
::
size
()
==
Orders
::
size
(),
"wrong! inconsistent size"
);
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
[
&
](
auto
i
)
{
static_assert
(
Lengths
{}.
at
(
i
)
%
Unpacks
{}.
at
(
i
)
==
0
);
});
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_access
()
{
using
L_
=
decltype
(
Lengths
{}
/
Unpacks
{});
return
reduce_on_sequence
(
L_
{},
multiplies
{},
number
<
1
>
{});
}
// F signature: F(sequence<...> multi_id...)
// multi_id is the unordered multi-index
template
<
class
F
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
)
const
{
constexpr
auto
ordered_lengths
=
Lengths
::
reorder_new_to_old
(
Orders
{});
constexpr
auto
ordered_unpacks
=
Unpacks
::
reorder_new_to_old
(
Orders
{});
detail
::
static_uford_impl
<
decltype
(
ordered_lengths
),
decltype
(
ordered_unpacks
),
Orders
>
{}(
f
,
make_tuple
(
sequence
<>
{}));
}
// this version is friendly for issue function one by one
template
<
class
F
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
number
<
i_access
>
)
const
{
static_assert
(
i_access
<
get_num_of_access
());
constexpr
auto
ordered_lengths
=
Lengths
::
reorder_new_to_old
(
Orders
{});
constexpr
auto
ordered_unpacks
=
Unpacks
::
reorder_new_to_old
(
Orders
{});
detail
::
static_uford_one_shot_impl
<
decltype
(
ordered_lengths
),
decltype
(
ordered_unpacks
),
Orders
>
{}(
f
,
make_tuple
(
sequence
<>
{}),
number
<
i_access
>
{});
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d.hpp
View file @
5cfd751b
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#pragma once
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
5cfd751b
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
// clang-format off
4-level descriptor: BlockTile-> BlockWarps-> WarpTile-> Vector
Block_N (Warp_N * BlockWarps_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <BlockWarps_N(2)> -->+
Warp_M
+--------------+--------------+--------------+--------------+----+----------------+
Warp_N | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <BlockWarps_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | | (Warp_M *
BlockWarps_M * Repeat_M )
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template
<
typename
BlockTile_
,
// block size, seq<M, N>
typename
BlockWarps_
,
// num warps along seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
BlockWarps_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Layernorm2dShape
{
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
// num warps along seq<M, N>, within each block
static
constexpr
index_t
BlockWarps_M
=
BlockWarps_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
BlockWarps_N
=
BlockWarps_
::
at
(
number
<
1
>
{});
// warp size
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static_assert
(
Block_M
%
(
BlockWarps_M
*
Warp_M
)
==
0
);
static_assert
(
Block_N
%
(
BlockWarps_N
*
Warp_N
)
==
0
);
// repeat of each thread along seq<M, N>
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
BlockWarps_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
BlockWarps_N
*
Warp_N
);
// vector size along seq<M, N>
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_N
=
Vector_
::
at
(
number
<
1
>
{});
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
// num of threads along seq<M, N>, within each warp
static
constexpr
index_t
Thread_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
Thread_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
namespace
ck_tile
{
struct
Layernorm2dFwdWarpPerRowDefaultPolicy
{
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeXBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
S
::
BlockWarps_M
,
S
::
Thread_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
BlockWarps_N
,
S
::
Thread_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
{
using
S
=
typename
Problem
::
BlockShape
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
S
::
BlockWarps_M
,
S
::
Thread_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
BlockWarps_N
,
S
::
Thread_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelford
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelford
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelfordSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordCrossWarpSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
return
BlockWelfordCrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
XDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_welford
=
BlockWelford
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
XDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
return
GetBlockWelfordCrossWarpSync
<
Problem
>
()
.
template
GetSmemSize
<
mean_var_block_tile
>();
}
else
{
return
1
;
// zero size arrays are an extension
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_pipeline.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dFwdWarpPerRowDefaultPolicy
>
struct
Layernorm2dFwdWarpPerRowPipeline
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
else
return
"wpr"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
x
=
load_tile
(
x_window
);
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
)
+
epsilon
);
},
var
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_window
,
cast_tile
<
MeanDataType
>
(
mean
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
// layernorm computation
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
);
});
store_tile
(
y_window
,
y
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/
block_
layernorm2d_fwd_problem.hpp
→
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_
warp_per_row_
problem.hpp
View file @
5cfd751b
...
@@ -18,7 +18,7 @@ template <typename XDataType_,
...
@@ -18,7 +18,7 @@ template <typename XDataType_,
bool
kPadN_
,
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
bool
kTwoPass_
>
struct
Block
Layernorm2dFwdProblem
struct
Layernorm2dFwd
WarpPerRow
Problem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
...
@@ -29,6 +29,9 @@ struct BlockLayernorm2dFwdProblem
...
@@ -29,6 +29,9 @@ struct BlockLayernorm2dFwdProblem
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
Thread_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
BlockWarps_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
...
...
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
deleted
100644 → 0
View file @
68e67701
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ThreadTile
,
// Sequence<...
typename
WarpTile
,
// Sequence<...
typename
BlockTile
>
// Sequence<...
struct
TileLayernorm2dShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNRepeat
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kNPerThread
=
ThreadTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kNThreadPerWarp
=
kNPerWarp
/
kNPerThread
/
kNRepeat
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMWarpPerBlock
=
kMPerBlock
/
kMPerWarp
;
static
constexpr
index_t
kNWarpPerBlock
=
kNPerBlock
/
kNPerWarp
;
// TODO - kNNumWarps can only be 1 if we don't support cross warp welford
static_assert
(
kNWarpPerBlock
==
1
);
static
constexpr
index_t
kBlockSize
=
warpSize
*
kMWarpPerBlock
*
kNWarpPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/welford.hpp
View file @
5cfd751b
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/warp/warp_welford.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/welford/block/block_welford.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelford
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
CK_TILE_DEVICE
constexpr
BlockWelford
()
{}
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN welford will have different
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
template
<
typename
XDistributedTensor_
,
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
const
XDistributedTensor_
&
x_tensor
,
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
cur_count_
,
// -> prefer init as zero
const
int
&
max_count_
)
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
spans
=
XDistributedTensor_
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
I1
],
[
&
](
auto
dstr_idx_i1
)
{
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
);
});
}
});
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
static
auto
MakeMeanVarBlockTile
()
{
static_assert
(
std
::
is_same_v
<
XDataType
,
typename
XDistributedTensor_
::
DataType
>
,
"wrong!"
);
constexpr
auto
reduce_dims
=
sequence
<
1
>
{};
constexpr
auto
dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
XDistributedTensor_
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
reduce_dims
));
auto
tensor
=
make_static_distributed_tensor
<
ComputeDataType
>
(
dstr
);
return
tensor
;
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
auto
operator
()(
const
XDistributedTensor_
&
x_tensor
,
int
&
cur_count_
,
const
int
&
max_count_
)
{
auto
mean_tensor
=
MakeMeanVarBlockTile
<
XDistributedTensor_
>
();
auto
var_tensor
=
MakeMeanVarBlockTile
<
XDistributedTensor_
>
();
clear_tile
(
mean_tensor
);
clear_tile
(
var_tensor
);
(
*
this
)(
x_tensor
,
mean_tensor
,
var_tensor
,
cur_count_
,
max_count_
);
return
ck_tile
::
make_tuple
(
mean_tensor
,
var_tensor
);
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelfordSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
)
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
const
int
original_count
=
count
;
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local_mean
=
mean_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_var
=
var_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_count
=
original_count
;
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// xor
index_t
src_lane
=
(
__lane_id
())
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
}
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockWelfordCrossWarpSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
template
<
typename
MeanDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
{
constexpr
index_t
num_reduce_warps
=
[
&
]()
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_warp
=
0
;
index_t
len_
=
1
;
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_warp
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
len_
*=
r_length
;
}
});
return
len_
;
}();
return
num_reduce_warps
;
}
// return in byte
template
<
typename
MeanDistributedTensor_
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
// constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
// data need to exchange is very small, we just pack mean+var+count -> 4dword
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
// we need to store all data from every wave into smem
// e.g. 2x2 reduce along N
// -------------> reduce N
// | w0 | w1 | ___> | w01 |
// | w2 | w3 | | w23 |
//
// -> store data from every wave into LDS
//
//
// -------------> reduce N
// | w0 | w1 | w2 | w3 | -----> | w0123 |
//
// -> also store data from every wave into LDS
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
return
num_warps
*
4
*
thread_buf_size
*
sizeof
(
float
);
}
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
,
void
*
smem
)
{
using
DataType
=
typename
MeanDistributedTensor_
::
DataType
;
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
// using DstrEncode = typename Dstr::DstrEncode;
// using DstrEncodeDetail = typename DstrEncode::detail;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
// Note: we always pack everything into fp32x4
fp32x4_t
*
smem_ptr
=
reinterpret_cast
<
fp32x4_t
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
constexpr
index_t
num_warps
=
BlockShape
::
BlockSize
/
warpSize
;
const
index_t
smem_offset
=
warp_id
;
// skip if nonthing to do
if
constexpr
(
num_reduce_warps
==
1
)
return
;
// store into smem only for lane-0 within one warp
if
(
lane_id
==
0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
fp32x4_t
local_scratch_
;
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
});
}
block_sync_lds
();
// load from smem. here we let everythread to do compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_reduce_warps
+
local_smem_os
+
i_1
];
});
});
block_sync_lds
();
// TODO: we don't need sync here
// const int original_count = count;
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
// compute the max count for a last dim reduce
// everything may have vector/repeat, so the max count could be uneven
// TODO: specify which dim to compute and proper set the problem
// TODO: BlockShape we reuse layernorm_fwd_shape :)
template
<
typename
BlockShape
>
CK_TILE_DEVICE
constexpr
index_t
block_tile_welford_calculate_max_count
(
int
row_size
)
{
using
S
=
BlockShape
;
index_t
LastloopN
=
row_size
%
S
::
Block_N
==
0
?
S
::
Block_N
:
row_size
%
S
::
Block_N
;
constexpr
index_t
NThread
=
S
::
BlockWarps_N
*
S
::
Thread_N
;
index_t
iNLane
=
get_thread_id
()
%
NThread
;
index_t
iN0
=
LastloopN
/
(
S
::
Vector_N
*
S
::
Thread_N
);
index_t
iN1
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread_N
))
/
S
::
Vector_N
;
index_t
N2
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread_N
))
%
S
::
Vector_N
;
index_t
iN3
=
iNLane
<
iN1
?
S
::
Vector_N
:
iNLane
==
iN1
?
N2
:
0
;
return
iN0
*
S
::
Vector_N
+
iN3
;
}
// Note: this function must be called after all the computation
template
<
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
constexpr
void
block_tile_welford_post_scale_var
(
VarDistributedTensor_
&
var_tensor
,
int
count
)
{
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
tile_elementwise_inout
([
&
count
](
auto
&
x
)
{
x
=
x
/
type_convert
<
DataType
>
(
count
);
},
var_tensor
);
}
}
// namespace ck_tile
include/ck_tile/ops/welford/block/block_welford_problem.hpp
0 → 100644
View file @
5cfd751b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
XDataType_
,
typename
ComputeDataType_
,
typename
BlockShape_
>
struct
BlockWelfordProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/welford/thread/thread_welford.hpp
View file @
5cfd751b
...
@@ -7,95 +7,30 @@
...
@@ -7,95 +7,30 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
ComputeDataType_
,
typename
XDataType_
>
template
<
typename
T
>
struct
ThreadWelford
CK_TILE_DEVICE
void
welford_update
(
T
&
mean
,
T
&
var
,
T
x
,
int
count
)
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
// TODO: check nan? maybe no
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
T
delta
=
x
-
mean
;
mean
+=
delta
/
count
;
template
<
typename
T
>
T
delta2
=
x
-
mean
;
CK_TILE_DEVICE
void
Update
(
T
&
mean
,
T
&
var
,
T
x
)
var
+=
delta
*
delta2
;
{
}
if
(
ck_tile
::
isnan
(
x
))
{
template
<
typename
T
>
mean
=
x
;
CK_TILE_DEVICE
static
void
var
=
x
;
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
}
{
else
int
count
=
count_a
+
count_b
;
{
T
count_
=
type_convert
<
T
>
(
count
);
T
delta
=
x
-
mean
;
T
count_a_
=
type_convert
<
T
>
(
count_a
);
mean
+=
delta
/
cur_count_
;
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
delta2
=
x
-
mean
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
var
+=
delta
*
delta2
;
}
T
delta
=
mean_b
-
mean_a
;
}
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a_
*
count_b_over_count
;
// [CAUSION] - max_count_ is to deal with the padding problem
count_a
=
count
;
// max_count_ is depend on caller, eg: naive and splitN welford will have different
}
// calculation of max_count_
CK_TILE_DEVICE
constexpr
ThreadWelford
(
int
max_count
)
:
cur_count_
(
0
),
max_count_
(
max_count
)
{}
template
<
typename
XDistributedTensor_
,
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
const
XDistributedTensor_
&
x_tensor
,
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
)
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
spans
=
XDistributedTensor_
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
I1
],
[
&
](
auto
dstr_idx_i1
)
{
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
Update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
);
});
}
});
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
static
auto
MakeInitialMeanVarDistributedTensor
()
{
static_assert
(
std
::
is_same_v
<
XDataType
,
typename
XDistributedTensor_
::
DataType
>
,
"wrong!"
);
constexpr
auto
reduce_dims
=
sequence
<
1
>
{};
constexpr
auto
dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
XDistributedTensor_
::
get_tile_distribution
()
.
get_static_tile_distribution_encoding
(),
reduce_dims
));
auto
tensor
=
make_static_distributed_tensor
<
ComputeDataType
>
(
dstr
);
clear_tile
(
tensor
);
return
tensor
;
}
template
<
typename
XDistributedTensor_
>
CK_TILE_DEVICE
auto
operator
()(
const
XDistributedTensor_
&
x_tensor
)
{
auto
mean_tensor
=
MakeInitialMeanVarDistributedTensor
<
XDistributedTensor_
>
();
auto
var_tensor
=
MakeInitialMeanVarDistributedTensor
<
XDistributedTensor_
>
();
(
*
this
)(
x_tensor
,
mean_tensor
,
var_tensor
);
return
ck_tile
::
make_tuple
(
mean_tensor
,
var_tensor
);
}
int
cur_count_
;
int
max_count_
;
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/welford/warp/warp_welford.hpp
deleted
100644 → 0
View file @
68e67701
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ComputeDataType_
,
bool
BroadcastLane
=
true
,
bool
GetActualVariance
=
true
>
struct
WarpMergeWelford
{
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
template
<
typename
T
>
CK_TILE_DEVICE
static
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
{
int
count
=
count_a
+
count_b
;
T
count_
=
type_convert
<
T
>
(
count
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a_
*
count_b_over_count
;
count_a
=
count
;
}
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
operator
()(
MeanDistributedTensor_
&
mean_tensor
,
VarDistributedTensor_
&
var_tensor
,
int
&
count
)
{
using
Dstr
=
typename
MeanDistributedTensor_
::
StaticTileDistribution
;
using
DstrEncode
=
typename
Dstr
::
DstrEncode
;
using
DstrEncodeDetail
=
typename
DstrEncode
::
detail
;
static_assert
(
std
::
is_same_v
<
Dstr
,
typename
VarDistributedTensor_
::
StaticTileDistribution
>
,
"wrong!"
);
constexpr
index_t
NDimP
=
Dstr
::
get_num_of_dimension_p
();
constexpr
index_t
NDimR
=
Dstr
::
get_num_of_dimension_r
();
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
// const auto rs_idx =
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
const
int
original_count
=
count
;
// loop over thread data
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
auto
v_local_mean
=
mean_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_var
=
var_tensor
.
get_thread_buffer
()[
i
];
auto
v_local_count
=
original_count
;
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
idim_p_lane
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// xor
index_t
src_lane
=
(
__lane_id
())
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
Merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
);
});
}
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
if
constexpr
(
GetActualVariance
)
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
/
v_local_count
;
else
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
});
}
};
}
// namespace ck_tile
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment