Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4cef2fc5
Commit
4cef2fc5
authored
Oct 21, 2024
by
carlushuang
Browse files
add n1536
parent
970be900
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
4 deletions
+40
-4
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
.../ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
+10
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp
+13
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp
+13
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_pipeline.hpp
...layernorm2d/pipeline/layernorm2d_fwd_rowwise_pipeline.hpp
+4
-4
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
View file @
4cef2fc5
...
@@ -58,6 +58,16 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
...
@@ -58,6 +58,16 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
else
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
n
<=
1536
)
{
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
else
if
(
a
.
n
<=
2048
)
{
if
(
a
.
n
%
8
==
0
)
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
...
...
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1536_instance.cpp
0 → 100644
View file @
4cef2fc5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// Explicit instantiations of layernorm2d_fwd_ for bf16 data, one per trait_
// configuration that the n <= 1536 branch of the dispatcher in
// layernorm2d_fwd_api.cpp selects (chosen there by n % 8, n % 4, n % 2, else).
//
// The element type here must be ck_tile::bf16_t. Instantiating ck_tile::fp16_t
// in this file would only repeat layernorm2d_fwd_fp16_n1536_instance.cpp and
// leave the bf16 variants without a definition.
//
// In every row below rn * tn * vn == 1536.
// NOTE(review): the column header abbreviates the trait_ template parameters;
// confirm their exact meaning against the trait_ declaration (presumably in
// layernorm2d_fwd_instance_common.hpp).
// clang-format off
//                                                       rm  rn  tm   tn  vn    pd     mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t,  1,  3,  4,  64,  8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t,  1,  3,  2, 128,  4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t,  1,  3,  1, 256,  2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t,  1,  6,  1, 256,  1, true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1536_instance.cpp
0 → 100644
View file @
4cef2fc5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// Explicit instantiations of layernorm2d_fwd_ for fp16 data, one per trait_
// configuration that the n <= 1536 branch of the dispatcher in
// layernorm2d_fwd_api.cpp selects (chosen there by n % 8, n % 4, n % 2, else).
//
// In every row below rn * tn * vn == 1536.
// NOTE(review): the column header abbreviates the trait_ template parameters;
// confirm their exact meaning against the trait_ declaration (presumably in
// layernorm2d_fwd_instance_common.hpp).
// clang-format off
//                                                       rm  rn  tm   tn  vn    pd     mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t,  1,  3,  4,  64,  8, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t,  1,  3,  2, 128,  4, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t,  1,  3,  1, 256,  2, true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t,  1,  6,  1, 256,  1, true, false, false>>(const S&, A);
// clang-format on
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_rowwise_pipeline.hpp
View file @
4cef2fc5
...
@@ -78,6 +78,10 @@ struct Layernorm2dFwdRowwisePipeline
...
@@ -78,6 +78,10 @@ struct Layernorm2dFwdRowwisePipeline
auto
block_welford_cross_warp_sync
=
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
// compute welford each-thread->cross-lane->cross-warp
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
...
@@ -96,10 +100,6 @@ struct Layernorm2dFwdRowwisePipeline
...
@@ -96,10 +100,6 @@ struct Layernorm2dFwdRowwisePipeline
if
constexpr
(
kSaveInvStd
)
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
// layernorm computation
// layernorm computation
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment