Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
98395085
Commit
98395085
authored
Oct 16, 2024
by
rocking
Browse files
Add kMThreadPerBlock to template parameter
parent
03247367
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
145 additions
and
116 deletions
+145
-116
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_instance.cpp
...2_layernorm2d/instances/layernorm2d_fwd_bf16_instance.cpp
+16
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_pad_instance.cpp
...yernorm2d/instances/layernorm2d_fwd_bf16_pad_instance.cpp
+18
-16
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_instance.cpp
...2_layernorm2d/instances/layernorm2d_fwd_fp16_instance.cpp
+16
-14
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_pad_instance.cpp
...yernorm2d/instances/layernorm2d_fwd_fp16_pad_instance.cpp
+24
-22
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+11
-7
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
+60
-43
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_instance.cpp
View file @
98395085
...
...
@@ -5,28 +5,30 @@
#include "layernorm2d_fwd_instance_common.hpp"
template
<
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kTwoPass
>
using
t
=
layernorm2d_fwd_traits_
<
ck_tile
::
bf16_t
,
NRepeat
,
NThread
,
kMThreadPerBlock
,
kNThreadPerBlock
,
VectorAccessSize
,
false
,
false
,
kTwoPass
>
;
// Disable all vector-8 fp16 read/write instances, as they have a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, true>>(const S&, A);
template
float
layernorm2d_fwd_
<
t
<
1
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_pad_instance.cpp
View file @
98395085
...
...
@@ -5,12 +5,14 @@
#include "layernorm2d_fwd_instance_common.hpp"
template
<
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kTwoPass
>
using
t
=
layernorm2d_fwd_traits_
<
ck_tile
::
bf16_t
,
NRepeat
,
NThread
,
kMThreadPerBlock
,
kNThreadPerBlock
,
VectorAccessSize
,
true
,
false
,
...
...
@@ -24,19 +26,19 @@ using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
template
float
layernorm2d_fwd_
<
t
<
1
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
64
,
2
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
4
,
64
,
2
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
64
,
1
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
64
,
1
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
4
,
64
,
1
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
4
,
64
,
1
,
true
>
>
(
const
S
&
,
A
);
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_instance.cpp
View file @
98395085
...
...
@@ -5,28 +5,30 @@
#include "layernorm2d_fwd_instance_common.hpp"
template
<
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kTwoPass
>
using
t
=
layernorm2d_fwd_traits_
<
ck_tile
::
fp16_t
,
NRepeat
,
NThread
,
kMThreadPerBlock
,
kNThreadPerBlock
,
VectorAccessSize
,
false
,
false
,
kTwoPass
>
;
// Disable all vector-8 fp16 read/write instances, as they have a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, true>>(const S&, A);
template
float
layernorm2d_fwd_
<
t
<
1
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_pad_instance.cpp
View file @
98395085
...
...
@@ -5,38 +5,40 @@
#include "layernorm2d_fwd_instance_common.hpp"
template
<
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kTwoPass
>
using
t
=
layernorm2d_fwd_traits_
<
ck_tile
::
fp16_t
,
NRepeat
,
NThread
,
kMThreadPerBlock
,
kNThreadPerBlock
,
VectorAccessSize
,
true
,
false
,
kTwoPass
>
;
// Disable all vector-8 fp16 read/write instances, as they have a compiler-related performance issue
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4,
4,
64, 8, true>>(const S&, A);
template
float
layernorm2d_fwd_
<
t
<
1
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
32
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
4
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
64
,
2
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
1
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
2
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
4
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
8
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
4
,
64
,
2
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
16
,
4
,
64
,
2
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
64
,
1
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
64
,
1
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
4
,
64
,
1
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
t
<
32
,
4
,
64
,
1
,
true
>
>
(
const
S
&
,
A
);
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
98395085
...
...
@@ -52,7 +52,8 @@ struct layernorm2d_fwd_args
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template
<
typename
DataType_
,
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
...
...
@@ -62,14 +63,17 @@ struct layernorm2d_fwd_traits_
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
ck_tile
::
index_t
MRepeat
=
1
;
static_assert
(
NThread
<=
64
,
"We only support intra-wave reduction"
);
static
constexpr
ck_tile
::
index_t
WaveNum
=
NThread
/
16
;
static_assert
(
kNThreadPerBlock
<=
64
,
"We only support intra-wave reduction"
);
static
constexpr
ck_tile
::
index_t
kNWarpPerBlock
=
1
;
static
constexpr
ck_tile
::
index_t
kMWarpPerBlock
=
kMThreadPerBlock
*
kNThreadPerBlock
/
warpSize
;
// kNThreadPerBlock / 16;
using
thread_tile
=
ck_tile
::
sequence
<
MRepeat
,
NRepeat
,
VectorAccessSize
>
;
using
warp_tile
=
ck_tile
::
sequence
<
MRepeat
*
64
/
NThread
,
NRepeat
*
NThread
*
VectorAccessSize
>
;
using
block_tile
=
ck_tile
::
sequence
<
MRepeat
*
WaveNum
*
64
/
NThread
,
NRepeat
*
NThread
*
VectorAccessSize
>
;
using
warp_tile
=
ck_tile
::
sequence
<
MRepeat
*
warpSize
/
kNThreadPerBlock
,
NRepeat
*
k
NThread
PerBlock
*
VectorAccessSize
>
;
using
block_tile
=
ck_tile
::
sequence
<
kMWarpPerBlock
*
MRepeat
*
warpSize
/
kNThreadPerBlock
,
NRepeat
*
k
NThread
PerBlock
*
VectorAccessSize
>
;
using
Shape
=
ck_tile
::
TileLayernorm2dShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
View file @
98395085
...
...
@@ -6,12 +6,19 @@
template
<
typename
DataType
,
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
kMThreadPerBlock
,
ck_tile
::
index_t
kNThreadPerBlock
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kPadN
,
bool
kTwoPass
=
false
>
using
trait_
=
layernorm2d_fwd_traits_
<
DataType
,
NRepeat
,
NThread
,
VectorAccessSize
,
kPadN
,
false
,
kTwoPass
>
;
using
trait_
=
layernorm2d_fwd_traits_
<
DataType
,
NRepeat
,
kMThreadPerBlock
,
kNThreadPerBlock
,
VectorAccessSize
,
kPadN
,
false
,
kTwoPass
>
;
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
t
,
layernorm2d_fwd_args
a
,
...
...
@@ -24,70 +31,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
128
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
32
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
32
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
256
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
512
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
{
return
a
.
N
%
2048
==
0
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
4
,
64
,
4
,
false
,
true
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
4
,
64
,
4
,
true
,
true
>>
(
s
,
a
);
}
}
else
if
(
a
.
N
%
2
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
256
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
2
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
512
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
4
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
1024
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
8
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
2048
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
16
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
16
,
4
,
64
,
2
,
true
,
true
>>
(
s
,
a
);
}
}
else
{
return
a
.
N
<=
2048
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
32
,
64
,
1
,
true
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
32
,
64
,
1
,
true
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
32
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
32
,
4
,
64
,
1
,
true
,
true
>>
(
s
,
a
);
}
}
else
if
(
t
.
data_type
.
compare
(
"bf16"
)
==
0
)
...
...
@@ -96,70 +108,75 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
32
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
32
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
128
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
32
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
32
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
64
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
256
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
64
,
4
,
true
>>
(
s
,
a
);
return
a
.
N
==
512
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
64
,
4
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
4
,
64
,
4
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
4
,
64
,
4
,
true
>>
(
s
,
a
);
}
else
{
return
a
.
N
%
2048
==
0
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
64
,
4
,
false
,
true
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
64
,
4
,
true
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
4
,
64
,
4
,
false
,
true
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
4
,
64
,
4
,
true
,
true
>>
(
s
,
a
);
}
}
else
if
(
a
.
N
%
2
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
256
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
2
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
512
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
4
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
1024
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
8
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
if
(
a
.
N
<=
2048
)
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
16
,
64
,
2
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
16
,
4
,
64
,
2
,
true
>>
(
s
,
a
);
}
else
{
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
16
,
64
,
2
,
true
,
true
>>
(
s
,
a
);
return
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
16
,
4
,
64
,
2
,
true
,
true
>>
(
s
,
a
);
}
}
else
{
return
a
.
N
<=
2048
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
32
,
64
,
1
,
true
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
32
,
64
,
1
,
true
,
true
>>
(
s
,
a
);
?
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
32
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
)
:
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
32
,
4
,
64
,
1
,
true
,
true
>>
(
s
,
a
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment