Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9a681c7d
Commit
9a681c7d
authored
Nov 04, 2024
by
dummycoderfe
Browse files
change block tile and tests ok
parent
4d7e063a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
2 deletions
+5
-2
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
.../ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
+1
-1
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp
...rm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp
+1
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
...rm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
+1
-0
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+2
-1
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
View file @
9a681c7d
...
@@ -33,7 +33,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
...
@@ -33,7 +33,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
// clang-format off
// clang-format off
// rm rn tm tn vn pd mv 2p
// rm rn tm tn vn pd mv 2p
if
(
a
.
n
<=
64
)
{
if
(
a
.
n
<=
64
)
{
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
16
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
n
<=
128
)
{
else
if
(
a
.
n
<=
128
)
{
if
(
a
.
n
%
2
==
0
)
if
(
a
.
n
%
2
==
0
)
...
...
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n64_n128_instance.cpp
View file @
9a681c7d
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd mv 2p
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
16
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n64_n128_instance.cpp
View file @
9a681c7d
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd mv 2p
// rm rn tm tn vn pd mv 2p
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
16
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
9a681c7d
...
@@ -101,7 +101,8 @@ struct layernorm2d_fwd_traits_
...
@@ -101,7 +101,8 @@ struct layernorm2d_fwd_traits_
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Layernorm2dShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Layernorm2dShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
,
ThreadPerBlock_M_
*
ThreadPerBlock_N_
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment