Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cc1898fc
Commit
cc1898fc
authored
Oct 20, 2024
by
carlushuang
Browse files
fix name
parent
dc1c2bf8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
32 additions
and
33 deletions
+32
-33
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+4
-4
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
.../ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
+18
-19
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
.../pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
+4
-4
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp
...rnorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp
+2
-2
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+4
-4
No files found.
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
cc1898fc
...
@@ -53,9 +53,9 @@ struct Layernorm2dFwd
...
@@ -53,9 +53,9 @@ struct Layernorm2dFwd
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
index_t
Thread_N
=
Problem
::
BlockShape
::
Thread_N
;
static
constexpr
index_t
Thread
PerWarp
_N
=
Problem
::
BlockShape
::
Thread
PerWarp
_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N
;
static
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -125,7 +125,7 @@ struct Layernorm2dFwd
...
@@ -125,7 +125,7 @@ struct Layernorm2dFwd
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
Block
Warps
_M
)
+
"x"
+
_TS_
(
S_
::
Block
Warps
_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPer
Block_M
)
+
"x"
+
_TS_
(
S_
::
WarpPer
Block_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _SS_
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp
View file @
cc1898fc
...
@@ -9,20 +9,19 @@ namespace ck_tile {
...
@@ -9,20 +9,19 @@ namespace ck_tile {
/*
/*
// clang-format off
// clang-format off
4-level descriptor: BlockTile-> Block
Warps
-> WarpTile-> Vector
4-level descriptor: BlockTile->
WarpPer
Block-> WarpTile-> Vector
Block_N (Warp_N * Block
Warps
_N * Repeat_N )
Block_N (Warp_N *
WarpPer
Block_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
+<----------------------< Repeat_N(2)>--------------------->+
| |
| |
+<-- <Block
Warps
_N(2)> -->+
+<-- <
WarpPer
Block_N(2)> -->+
Warp_M
Warp_M
+--------------+--------------+--------------+--------------+----+----------------+
+--------------+--------------+--------------+--------------+----+----------------+
Warp_N | wrap_0 | wrap_1 | | ^ ^
Warp_N | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <Block
Warps
_M(2)> |
+--------------+--------------+ | <
WarpPer
Block_M(2)> |
| wrap_2 | wrap_3 | | v
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
+--------------+--------------+--------------+--------------+----+ Block_M
| | | (Warp_M *
| | | (Warp_M * WarpPerBlock_M * Repeat_M )
BlockWarps_M * Repeat_M )
+ + |
+ + |
| | | v
| | | v
+--------------+--------------+--------------+--------------+ +
+--------------+--------------+--------------+--------------+ +
...
@@ -37,12 +36,12 @@ BlockWarps_M * Repeat_M )
...
@@ -37,12 +36,12 @@ BlockWarps_M * Repeat_M )
+-----------+-----------+-----------+-----------+-----------+
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
// clang-format on
*/
*/
template
<
typename
BlockTile_
,
// block size, seq<M, N>
template
<
typename
BlockTile_
,
// block size, seq<M, N>
typename
Block
Warps
_
,
// num warps along seq<M, N>
typename
WarpPer
Block_
,
// num warps along seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
Block
Warps
_
{}
,
multiplies
{}
,
number
<
1
>{})
>
warpSize
*
reduce_on_sequence
(
WarpPer
Block_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Layernorm2dShape
struct
Layernorm2dShape
{
{
// block size
// block size
...
@@ -50,18 +49,18 @@ struct Layernorm2dShape
...
@@ -50,18 +49,18 @@ struct Layernorm2dShape
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
// num warps along seq<M, N>, within each block
// num warps along seq<M, N>, within each block
static
constexpr
index_t
BlockWarps_M
=
BlockWarps
_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_M
=
WarpPerBlock
_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
BlockWarps_N
=
BlockWarps
_
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
WarpPerBlock
_
::
at
(
number
<
1
>
{});
// warp size
// warp size
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static_assert
(
Block_M
%
(
Block
Warps
_M
*
Warp_M
)
==
0
);
static_assert
(
Block_M
%
(
WarpPer
Block_M
*
Warp_M
)
==
0
);
static_assert
(
Block_N
%
(
Block
Warps
_N
*
Warp_N
)
==
0
);
static_assert
(
Block_N
%
(
WarpPer
Block_N
*
Warp_N
)
==
0
);
// repeat of each thread along seq<M, N>
// repeat of each thread along seq<M, N>
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
Block
Warps
_M
*
Warp_M
);
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPer
Block_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
Block
Warps
_N
*
Warp_N
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPer
Block_N
*
Warp_N
);
// vector size along seq<M, N>
// vector size along seq<M, N>
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
...
@@ -70,8 +69,8 @@ struct Layernorm2dShape
...
@@ -70,8 +69,8 @@ struct Layernorm2dShape
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
// num of threads along seq<M, N>, within each warp
// num of threads along seq<M, N>, within each warp
static
constexpr
index_t
Thread_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
Thread
PerWarp
_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
Thread_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
Thread
PerWarp
_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
};
};
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_default_policy.hpp
View file @
cc1898fc
...
@@ -19,8 +19,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
...
@@ -19,8 +19,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
S
::
Block
Warps
_M
,
S
::
Thread_M
,
S
::
Vector_M
>
,
tuple
<
sequence
<
S
::
WarpPer
Block_M
,
S
::
Thread
PerWarp
_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
Block
Warps
_N
,
S
::
Thread_N
,
S
::
Vector_N
>>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPer
Block_N
,
S
::
Thread
PerWarp
_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
1
,
2
,
2
>
,
...
@@ -33,8 +33,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
...
@@ -33,8 +33,8 @@ struct Layernorm2dFwdWarpPerRowDefaultPolicy
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
S
::
Block
Warps
_M
,
S
::
Thread_M
>
,
sequence
<
S
::
WarpPer
Block_M
,
S
::
Thread
PerWarp
_M
>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
Block
Warps
_N
,
S
::
Thread_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
S
::
Repeat_N
,
S
::
WarpPer
Block_N
,
S
::
Thread
PerWarp
_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
1
>
,
sequence
<
1
,
1
>
,
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_warp_per_row_problem.hpp
View file @
cc1898fc
...
@@ -29,8 +29,8 @@ struct Layernorm2dFwdWarpPerRowProblem
...
@@ -29,8 +29,8 @@ struct Layernorm2dFwdWarpPerRowProblem
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
Thread_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
Thread
PerWarp
_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
Block
Warps
_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPer
Block_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
...
...
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
cc1898fc
...
@@ -324,11 +324,11 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
...
@@ -324,11 +324,11 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
{
{
using
S
=
BlockShape
;
using
S
=
BlockShape
;
index_t
LastloopN
=
row_size
%
S
::
Block_N
==
0
?
S
::
Block_N
:
row_size
%
S
::
Block_N
;
index_t
LastloopN
=
row_size
%
S
::
Block_N
==
0
?
S
::
Block_N
:
row_size
%
S
::
Block_N
;
constexpr
index_t
NThread
=
S
::
Block
Warps
_N
*
S
::
Thread_N
;
constexpr
index_t
NThread
=
S
::
WarpPer
Block_N
*
S
::
Thread
PerWarp
_N
;
index_t
iNLane
=
get_thread_id
()
%
NThread
;
index_t
iNLane
=
get_thread_id
()
%
NThread
;
index_t
iN0
=
LastloopN
/
(
S
::
Vector_N
*
S
::
Thread_N
);
index_t
iN0
=
LastloopN
/
(
S
::
Vector_N
*
S
::
Thread
PerWarp
_N
);
index_t
iN1
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread_N
))
/
S
::
Vector_N
;
index_t
iN1
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread
PerWarp
_N
))
/
S
::
Vector_N
;
index_t
N2
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread_N
))
%
S
::
Vector_N
;
index_t
N2
=
(
LastloopN
%
(
S
::
Vector_N
*
S
::
Thread
PerWarp
_N
))
%
S
::
Vector_N
;
index_t
iN3
=
iNLane
<
iN1
?
S
::
Vector_N
:
iNLane
==
iN1
?
N2
:
0
;
index_t
iN3
=
iNLane
<
iN1
?
S
::
Vector_N
:
iNLane
==
iN1
?
N2
:
0
;
return
iN0
*
S
::
Vector_N
+
iN3
;
return
iN0
*
S
::
Vector_N
+
iN3
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment