Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9d13f91b
Commit
9d13f91b
authored
Oct 20, 2024
by
carlushuang
Browse files
add more block-per-tile instance
parent
1cb3e443
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
11 deletions
+114
-11
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
.../ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
+27
-11
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp
+7
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp
+14
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp
+14
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp
+7
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp
+14
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
...rnorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
+14
-0
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+17
-0
No files found.
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_api.cpp
View file @
9d13f91b
...
@@ -12,7 +12,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
...
@@ -12,7 +12,7 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
#if 1
#if 1
float
r
=
-
1
;
float
r
=
-
1
;
// clang-format off
// clang-format off
//
rm rn tm tn vn pd mv 2p
// rm rn tm
tn vn pd mv 2p
if
(
a
.
n
<=
64
)
{
if
(
a
.
n
<=
64
)
{
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
...
@@ -49,18 +49,34 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
...
@@ -49,18 +49,34 @@ float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/,
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
n
<=
1024
)
{
else
if
(
a
.
n
<=
1024
)
{
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
128
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
if
(
a
.
n
%
4
==
0
)
else
if
(
a
.
n
%
4
==
0
)
// r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 4, true, false, false>>(s, a);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
128
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
128
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
8
==
0
)
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
4
,
64
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
2048
)
{
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
2
,
1
,
128
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
else
if
(
a
.
n
<=
3072
)
{
if
(
a
.
n
%
8
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
4
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
,
false
>>
(
s
,
a
);
else
if
(
a
.
n
%
2
==
0
)
else
if
(
a
.
n
%
2
==
0
)
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
8
,
4
,
6
4
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
6
,
1
,
25
6
,
2
,
true
,
false
,
false
>>
(
s
,
a
);
else
else
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
16
,
4
,
6
4
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
r
=
layernorm2d_fwd_
<
trait_
<
data_type
,
1
,
3
,
1
,
102
4
,
1
,
true
,
false
,
false
>>
(
s
,
a
);
}
}
return
r
;
return
r
;
#else
#else
...
...
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n1024_instance.cpp
View file @
9d13f91b
...
@@ -6,10 +6,17 @@
...
@@ -6,10 +6,17 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd mv 2p
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
2
,
1
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
128
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n2048_instance.cpp
0 → 100644
View file @
9d13f91b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of layernorm2d_fwd_ for ck_tile::bf16_t; these are the
// trait_ combinations the dispatcher in layernorm2d_fwd_api.cpp calls in its
// `a.n <= 2048` branch. The header row labels the eight trait_ arguments that
// follow the data type.
// NOTE(review): the column abbreviations are not expanded in this file - confirm
// their meaning against layernorm2d_fwd_instance_common.hpp.
//                                                      rm  rn  tm  tn    vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  1,  1,  256,  8,  true, false, false>>(const S&, A);
// NOTE(review): the dispatcher picks the next instance when n % 4 == 0 (after the
// n % 8 == 0 case), yet its vn is 8, whereas every sibling row uses a vn equal to the
// divisor it is selected for. Confirm whether <1, 2, 1, 256, 4> was intended; any
// change has to be made here and in the dispatcher together.
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  2,  1,  128,  8,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  4,  1,  256,  2,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  8,  1,  256,  1,  true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_bf16_n3072_instance.cpp
0 → 100644
View file @
9d13f91b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of layernorm2d_fwd_ for ck_tile::bf16_t; these are the
// trait_ combinations the dispatcher in layernorm2d_fwd_api.cpp calls in its
// `a.n <= 3072` branch (n % 8 == 0, n % 4 == 0, n % 2 == 0, otherwise - in that order).
// The header row labels the eight trait_ arguments that follow the data type.
// NOTE(review): the column abbreviations are not expanded in this file - confirm
// their meaning against layernorm2d_fwd_instance_common.hpp.
//                                                      rm  rn  tm  tn    vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1,  128,  8,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1,  256,  4,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  6,  1,  256,  2,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1,  3,  1,  1024, 1,  true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n1024_instance.cpp
View file @
9d13f91b
...
@@ -6,10 +6,17 @@
...
@@ -6,10 +6,17 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd mv 2p
// rm rn tm tn vn pd mv 2p
#if 0
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
1
,
128
,
8
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
128
,
4
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
128
,
2
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_fwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
1
,
true
,
false
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n2048_instance.cpp
0 → 100644
View file @
9d13f91b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of layernorm2d_fwd_ for ck_tile::fp16_t; these are the
// trait_ combinations the dispatcher in layernorm2d_fwd_api.cpp calls in its
// `a.n <= 2048` branch. The header row labels the eight trait_ arguments that
// follow the data type.
// NOTE(review): the column abbreviations are not expanded in this file - confirm
// their meaning against layernorm2d_fwd_instance_common.hpp.
//                                                      rm  rn  tm  tn    vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  1,  1,  256,  8,  true, false, false>>(const S&, A);
// NOTE(review): the dispatcher picks the next instance when n % 4 == 0 (after the
// n % 8 == 0 case), yet its vn is 8, whereas every sibling row uses a vn equal to the
// divisor it is selected for. Confirm whether <1, 2, 1, 256, 4> was intended; any
// change has to be made here and in the dispatcher together.
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  2,  1,  128,  8,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  4,  1,  256,  2,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  8,  1,  256,  1,  true, false, false>>(const S&, A);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_n3072_instance.cpp
0 → 100644
View file @
9d13f91b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp"
// clang-format off
// Explicit instantiations of layernorm2d_fwd_ for ck_tile::fp16_t; these are the
// trait_ combinations the dispatcher in layernorm2d_fwd_api.cpp calls in its
// `a.n <= 3072` branch (n % 8 == 0, n % 4 == 0, n % 2 == 0, otherwise - in that order).
// The header row labels the eight trait_ arguments that follow the data type.
// NOTE(review): the column abbreviations are not expanded in this file - confirm
// their meaning against layernorm2d_fwd_instance_common.hpp.
//                                                      rm  rn  tm  tn    vn  pd    mv     2p
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  3,  1,  128,  8,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  3,  1,  256,  4,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  6,  1,  256,  2,  true, false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1,  3,  1,  1024, 1,  true, false, false>>(const S&, A);
// clang-format on
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
9d13f91b
...
@@ -322,6 +322,7 @@ struct BlockWelfordCrossWarpSync
...
@@ -322,6 +322,7 @@ struct BlockWelfordCrossWarpSync
template
<
typename
BlockShape
>
template
<
typename
BlockShape
>
CK_TILE_DEVICE
constexpr
index_t
block_tile_welford_calculate_max_count
(
int
row_size
)
CK_TILE_DEVICE
constexpr
index_t
block_tile_welford_calculate_max_count
(
int
row_size
)
{
{
#if 0
using S = BlockShape;
using S = BlockShape;
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
index_t LastloopN = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
...
@@ -331,6 +332,22 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
...
@@ -331,6 +332,22 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
index_t N2 = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
index_t iN3 = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
return iN0 * S::Vector_N + iN3;
return iN0 * S::Vector_N + iN3;
#endif
using
S_
=
BlockShape
;
constexpr
index_t
ThreadsPerBlock_N
=
S_
::
WarpPerBlock_N
*
S_
::
ThreadPerWarp_N
;
// TODO: we always check the vector size; row_size needs to be evenly divisible by Vector_N
const
index_t
element_per_row
=
row_size
/
S_
::
Vector_N
;
index_t
lane_id_n
=
get_thread_id
()
%
ThreadsPerBlock_N
;
index_t
cnt
=
0
;
// TODO: Repeat_N can not be too long, otherwise this is not good
static_for
<
0
,
S_
::
Repeat_N
,
1
>
{}([
&
](
auto
)
{
index_t
_a
=
lane_id_n
<
element_per_row
?
1
:
0
;
cnt
+=
_a
;
lane_id_n
+=
ThreadsPerBlock_N
;
});
return
cnt
*
S_
::
Vector_N
;
}
}
// Note: this function must be called after all the computation
// Note: this function must be called after all the computation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment