gaoqiong / composable_kernel_ROCM · Commits

Commit 4e14a894, authored Oct 16, 2024 by rocking
Parent: 8c3d43cf

    refactor api

Showing 8 changed files with 220 additions and 304 deletions:
example/ck_tile/02_layernorm2d/CMakeLists.txt                                    +2  -1
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp                       +0  -4
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_instance.cpp       +70 -22
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_pad_instance.cpp   +80 -22
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_instance.cpp       +0  -36
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp                               +32 -0
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp                           +36 -130
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp                            +0  -89
example/ck_tile/02_layernorm2d/CMakeLists.txt

@@ -9,7 +9,8 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_api.cpp ${INST
 set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
-# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
+list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
 target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
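A note on why -Wno-undefined-func-template is needed: the instance .cpp files define layernorm2d_fwd_ and explicitly instantiate it, while other translation units only see the declaration from layernorm2d_fwd.hpp. A minimal sketch of that pattern, with hypothetical names rather than code from this repo:

    // Hypothetical single-file condensation of the header/instance split.
    // api.hpp would hold only the declaration:
    template <typename T>
    float run_kernel(int n);

    // instance.cpp would hold the definition plus the explicit
    // instantiations that callers are allowed to link against:
    template <typename T>
    float run_kernel(int n) { return static_cast<float>(n * sizeof(T)); }
    template float run_kernel<float>(int);

    // caller.cpp sees only the declaration; clang's -Wundefined-func-template
    // would warn at this call site, which is what the flag above silences.
    float call_it() { return run_kernel<float>(512); }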
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp

@@ -127,10 +127,6 @@ int main(int argc, char* argv[])
     {
         return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
     }
-    if(data_type == "fp32")
-    {
-        return run<float>(arg_parser) ? 0 : -2;
-    }
     return -3;
 }
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_instance.cpp

@@ -3,26 +3,74 @@
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <ck_tile/core.hpp>
-#include "layernorm_dispatch.hpp"
+#include "layernorm2d_fwd.hpp"

 // clang-format off
-// template float run_layernorm<ck_tile::fp16_t,  1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
 // clang-format on

+template <typename Traits_>
+float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
+{
+    using DataType = typename Traits_::DataType;
+
+    using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
+        typename LayerNormTypeConfig<DataType>::XDataType,
+        typename LayerNormTypeConfig<DataType>::GammaDataType,
+        typename LayerNormTypeConfig<DataType>::BetaDataType,
+        typename LayerNormTypeConfig<DataType>::ComputeDataType,
+        typename LayerNormTypeConfig<DataType>::YDataType,
+        typename LayerNormTypeConfig<DataType>::MeanDataType,
+        typename LayerNormTypeConfig<DataType>::InvStdDataType,
+        typename Traits_::Shape,
+        Traits_::kPadN,
+        Traits_::kSaveMeanInvStd,
+        Traits_::kTwoPass>;
+
+    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
+
+    const dim3 grids                       = Kernel::GridSize(a.M);
+    constexpr dim3 blocks                  = Kernel::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+
+    return ck_tile::launch_kernel(
+        s,
+        ck_tile::make_kernel<blocks.x, kBlockPerCu>(
+            Kernel{}, grids, blocks, 0,
+            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N));
+}
+
+template <ck_tile::index_t NRepeat,
+          ck_tile::index_t NThread,
+          ck_tile::index_t VectorAccessSize,
+          bool kTwoPass>
+using t = layernorm2d_fwd_traits_<ck_tile::fp16_t, NRepeat, NThread, VectorAccessSize, false, false, kTwoPass>;
+
+using S = ck_tile::stream_config;
+using A = layernorm2d_fwd_args;
+
+// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
+// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
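As a sanity check on these instance shapes (my own arithmetic, not part of the commit): per the layernorm2d_fwd_traits_ formulas in layernorm2d_fwd.hpp, one block-tile row covers NRepeat * NThread * VectorAccessSize columns, which is why the dispatch in layernorm2d_fwd_api.cpp keys each unpadded instance to one exact N:

    // Sketch: the row width each fp16 instance above targets, assuming the
    // block-tile column extent NRepeat * NThread * VectorAccessSize is the N
    // that one pass covers without padding.
    constexpr int columns_covered(int n_repeat, int n_thread, int vector_size)
    {
        return n_repeat * n_thread * vector_size;
    }

    static_assert(columns_covered(1, 32, 4) == 128,  "t<1, 32, 4, ...>");
    static_assert(columns_covered(1, 64, 4) == 256,  "t<1, 64, 4, ...>");
    static_assert(columns_covered(2, 64, 4) == 512,  "t<2, 64, 4, ...>");
    static_assert(columns_covered(4, 64, 4) == 1024, "t<4, 64, 4, ...>");
    static_assert(columns_covered(8, 64, 4) == 2048, "t<8, 64, 4, ...>");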
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_pad_instance.cpp

@@ -3,26 +3,84 @@
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <ck_tile/core.hpp>
-#include "layernorm_dispatch.hpp"
+#include "layernorm2d_fwd.hpp"

 // clang-format off
-// template float run_layernorm<ck_tile::fp16_t,  1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-// template float run_layernorm<ck_tile::fp16_t,  4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t,  8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
-template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
 // clang-format on

+template <typename Traits_>
+float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
+{
+    using DataType = typename Traits_::DataType;
+
+    using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
+        typename LayerNormTypeConfig<DataType>::XDataType,
+        typename LayerNormTypeConfig<DataType>::GammaDataType,
+        typename LayerNormTypeConfig<DataType>::BetaDataType,
+        typename LayerNormTypeConfig<DataType>::ComputeDataType,
+        typename LayerNormTypeConfig<DataType>::YDataType,
+        typename LayerNormTypeConfig<DataType>::MeanDataType,
+        typename LayerNormTypeConfig<DataType>::InvStdDataType,
+        typename Traits_::Shape,
+        Traits_::kPadN,
+        Traits_::kSaveMeanInvStd,
+        Traits_::kTwoPass>;
+
+    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
+
+    const dim3 grids                       = Kernel::GridSize(a.M);
+    constexpr dim3 blocks                  = Kernel::BlockSize();
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+
+    return ck_tile::launch_kernel(
+        s,
+        ck_tile::make_kernel<blocks.x, kBlockPerCu>(
+            Kernel{}, grids, blocks, 0,
+            a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N));
+}
+
+template <ck_tile::index_t NRepeat,
+          ck_tile::index_t NThread,
+          ck_tile::index_t VectorAccessSize,
+          bool kTwoPass>
+using t = layernorm2d_fwd_traits_<ck_tile::fp16_t, NRepeat, NThread, VectorAccessSize, true, false, kTwoPass>;
+
+using S = const ck_tile::stream_config;
+using A = layernorm2d_fwd_args;
+
+// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
+// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
+// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
+template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
+template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
+template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_instance.cpp
deleted (100644 → 0); its entire contents were:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
#ifdef CK_TILE_LAYERNORM2D_FWD_FP32_DEFAULT
template float run_layernorm<float,  1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float,  8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
#endif
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

@@ -49,6 +49,38 @@ struct layernorm2d_fwd_args
     ck_tile::index_t N;
 };

+// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
+template <typename DataType_,
+          ck_tile::index_t NRepeat,
+          ck_tile::index_t NThread,
+          ck_tile::index_t VectorAccessSize,
+          bool kPadN_,
+          bool kSaveMeanInvStd_,
+          bool kTwoPass_>
+struct layernorm2d_fwd_traits_
+{
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+
+    static constexpr ck_tile::index_t MRepeat = 1;
+
+    static_assert(NThread <= 64, "We only support intra-wave reduction");
+    static constexpr ck_tile::index_t WaveNum = NThread / 16;
+
+    using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
+    using warp_tile   = ck_tile::sequence<MRepeat * 64 / NThread,
+                                          NRepeat * NThread * VectorAccessSize>;
+    using block_tile  = ck_tile::sequence<MRepeat * WaveNum * 64 / NThread,
+                                          NRepeat * NThread * VectorAccessSize>;
+
+    using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
+
+    static constexpr bool kPadN           = kPadN_;
+    static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
+    static constexpr bool kTwoPass        = kTwoPass_;
+};
+
+template <typename Traits_>
+float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);

 // This is the public API, will be generated by script
 struct layernorm2d_fwd_traits
 {
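To make the tile arithmetic in layernorm2d_fwd_traits_ concrete, here is a worked example following the formulas above (my own derivation; these values are not asserted anywhere in the commit). For NRepeat = 2, NThread = 64, VectorAccessSize = 4:

    // With MRepeat = 1 and WaveNum = NThread / 16 = 4:
    //   thread_tile = sequence<1, 2, 4>          -- 1 row, 2 repeats of 4-wide vectors
    //   warp_tile   = sequence<1 * 64 / 64,      -- 1 row per wave
    //                          2 * 64 * 4>       -- 512 columns per wave
    //   block_tile  = sequence<1 * 4 * 64 / 64,  -- 4 rows per block (one per wave)
    //                          2 * 64 * 4>       -- 512 columns
    constexpr int MRepeat = 1, NRepeat = 2, NThread = 64, VectorAccessSize = 4;
    constexpr int WaveNum = NThread / 16;
    static_assert(MRepeat * 64 / NThread == 1, "warp rows");
    static_assert(NRepeat * NThread * VectorAccessSize == 512, "warp/block columns");
    static_assert(MRepeat * WaveNum * 64 / NThread == 4, "block rows");

So with NThread = 64 a single wave spans the whole row, and the four waves of a block each take one row of a 4 x 512 block tile.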
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp

@@ -2,7 +2,16 @@
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <ck_tile/core.hpp>
-#include "layernorm_dispatch.hpp"
+#include "layernorm2d_fwd.hpp"
+
+template <typename DataType,
+          ck_tile::index_t NRepeat,
+          ck_tile::index_t NThread,
+          ck_tile::index_t VectorAccessSize,
+          bool kPadN,
+          bool kTwoPass = false>
+using trait_ = layernorm2d_fwd_traits_<DataType, NRepeat, NThread, VectorAccessSize, kPadN, false, kTwoPass>;

 float layernorm2d_fwd(layernorm2d_fwd_traits t,
                       layernorm2d_fwd_args a,

@@ -11,182 +20,79 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
     float r = -1;
     if(t.data_type.compare("fp16") == 0)
     {
-        // Disable all vector 8fp16 read/write instances as it has performance issue regarding
-        // compiler
-#if 0
-        if(a.N % 8 == 0)
-        {
-            if(a.N <= 128)
-            {
-                return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(a, s);
-            }
-            else if(a.N <= 256)
-            {
-                return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(a, s);
-            }
-            else if(a.N <= 512)
-            {
-                return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(a, s);
-            }
-            else if(a.N <= 1024)
-            {
-                return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(a, s);
-            }
-            else
-            {
-                return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(a, s);
-            }
-        }
-        else if(a.N % 4 == 0)
-#endif
         if(a.N % 4 == 0)
         {
             if(a.N <= 128)
             {
-                return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(a, s);
+                return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, false>>(s, a)
+                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 32, 4, true>>(s, a);
             }
             else if(a.N <= 256)
             {
-                return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(a, s);
+                return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, false>>(s, a)
+                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 4, true>>(s, a);
             }
             else if(a.N <= 512)
             {
-                return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(a, s);
+                return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, false>>(s, a)
+                                  : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, true>>(s, a);
             }
             else if(a.N <= 1024)
             {
-                return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(a, s);
+                return a.N == 1024 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, false>>(s, a)
+                                   : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 4, true>>(s, a);
             }
             else if(a.N <= 2048)
             {
-                return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(a, s);
+                return a.N == 2048 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false>>(s, a)
+                                   : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true>>(s, a);
             }
             else
             {
-                return a.N % 2048 == 0 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(a, s)
-                                       : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(a, s);
+                return a.N % 2048 == 0 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, false, true>>(s, a)
+                                       : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 4, true, true>>(s, a);
             }
         }
         else if(a.N % 2 == 0)
         {
             if(a.N <= 128)
             {
-                return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 64, 2, true>>(s, a);
             }
             else if(a.N <= 256)
             {
-                return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 2, true>>(s, a);
             }
             else if(a.N <= 512)
             {
-                return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(a, s)
-                                  : run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 4, 64, 2, true>>(s, a);
             }
             else if(a.N <= 1024)
             {
-                return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 8, 64, 2, true>>(s, a);
             }
             else if(a.N <= 2048)
             {
-                return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(a, s)
-                                   : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true>>(s, a);
             }
             else
             {
-                return a.N % 2048 == 0 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(a, s)
-                                       : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(a, s);
+                return layernorm2d_fwd_<trait_<ck_tile::fp16_t, 16, 64, 2, true, true>>(s, a);
             }
         }
+        else
+        {
+            return a.N <= 2048 ? layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, false>>(s, a)
+                               : layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
+        }
     }
-#ifdef CK_TILE_LAYERNORM2D_FWD_FP32_DEFAULT
-    else if(t.data_type.compare("fp32") == 0)
-    {
-        if(a.N % 4 == 0)
-        {
-            if(a.N <= 128)
-            {
-                return a.N == 128 ? run_layernorm<float, 1, 32, 4, false>(a, s)
-                                  : run_layernorm<float, 1, 32, 4, true>(a, s);
-            }
-            else if(a.N <= 256)
-            {
-                return a.N == 256 ? run_layernorm<float, 1, 64, 4, false>(a, s)
-                                  : run_layernorm<float, 1, 64, 4, true>(a, s);
-            }
-            else if(a.N <= 512)
-            {
-                return a.N == 512 ? run_layernorm<float, 2, 64, 4, false>(a, s)
-                                  : run_layernorm<float, 2, 64, 4, true>(a, s);
-            }
-            else if(a.N <= 1024)
-            {
-                return a.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(a, s)
-                                   : run_layernorm<float, 4, 64, 4, true>(a, s);
-            }
-            else if(a.N <= 2048)
-            {
-                return a.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(a, s)
-                                   : run_layernorm<float, 8, 64, 4, true>(a, s);
-            }
-            else
-            {
-                return a.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(a, s)
-                                       : run_layernorm<float, 8, 64, 4, true, true>(a, s);
-            }
-        }
-        else if(a.N % 2 == 0)
-        {
-            if(a.N <= 128)
-            {
-                return a.N == 128 ? run_layernorm<float, 1, 64, 2, false>(a, s)
-                                  : run_layernorm<float, 1, 64, 2, true>(a, s);
-            }
-            else if(a.N <= 256)
-            {
-                return a.N == 256 ? run_layernorm<float, 2, 64, 2, false>(a, s)
-                                  : run_layernorm<float, 2, 64, 2, true>(a, s);
-            }
-            else if(a.N <= 512)
-            {
-                return a.N == 512 ? run_layernorm<float, 4, 64, 2, false>(a, s)
-                                  : run_layernorm<float, 4, 64, 2, true>(a, s);
-            }
-            else if(a.N <= 1024)
-            {
-                return a.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(a, s)
-                                   : run_layernorm<float, 8, 64, 2, true>(a, s);
-            }
-            else if(a.N <= 2048)
-            {
-                return a.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(a, s)
-                                   : run_layernorm<float, 16, 64, 2, true>(a, s);
-            }
-            else
-            {
-                return a.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(a, s)
-                                       : run_layernorm<float, 16, 64, 2, true, true>(a, s);
-            }
-        }
-    }
-#endif

     if(r < 0)
         throw std::runtime_error("Without supported instances!");

     return r;
 }
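For reference, a hypothetical call into the refactored public entry point. The third parameter of layernorm2d_fwd is elided in this diff; I assume it is the ck_tile::stream_config the body refers to as s, and that layernorm2d_fwd_traits exposes the data_type string compared above. Field names follow layernorm2d_fwd_args as used throughout the diff:

    #include <ck_tile/core.hpp>
    #include "layernorm2d_fwd.hpp"

    // Hypothetical caller; not code from this commit.
    float launch_layernorm2d_fp16(void* p_x, void* p_gamma, void* p_beta, void* p_y,
                                  ck_tile::index_t m, ck_tile::index_t n)
    {
        layernorm2d_fwd_args a{};
        a.p_x      = p_x;      // input  [M, N]
        a.p_gamma  = p_gamma;  // scale  [N]
        a.p_beta   = p_beta;   // shift  [N]
        a.p_y      = p_y;      // output [M, N]
        a.p_mean   = nullptr;  // SaveMeanInvStd instances are still a TODO here
        a.p_invStd = nullptr;
        a.epsilon  = 1e-5f;
        a.M        = m;
        a.N        = n;        // e.g. n == 512 selects trait_<ck_tile::fp16_t, 2, 64, 4, false>

        layernorm2d_fwd_traits t{};
        t.data_type = "fp16";  // assumption: the traits struct carries a data_type string

        return layernorm2d_fwd(t, a, ck_tile::stream_config{});
    }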
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
deleted (100644 → 0); its entire contents were:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"

template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kSaveMeanInvStd,
          bool kTwoPass>
struct layernorm_dispatch
{
    static constexpr ck_tile::index_t MRepeat = 1;

    static_assert(NThread <= 64, "We only support intra-wave reduction");
    static constexpr ck_tile::index_t WaveNum = NThread / 16;

    // clang-format off
    using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
    using warp_tile   = ck_tile::sequence<MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
    using block_tile  = ck_tile::sequence<MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
    // clang-format on

    using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

    using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<
        typename LayerNormTypeConfig<InOutDataType>::XDataType,
        typename LayerNormTypeConfig<InOutDataType>::GammaDataType,
        typename LayerNormTypeConfig<InOutDataType>::BetaDataType,
        typename LayerNormTypeConfig<InOutDataType>::ComputeDataType,
        typename LayerNormTypeConfig<InOutDataType>::YDataType,
        typename LayerNormTypeConfig<InOutDataType>::MeanDataType,
        typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
        Shape,
        kPadN,
        kSaveMeanInvStd,
        kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
    {
        using k_ = Kernel;

        const dim3 grids                       = k_::GridSize(param.M);
        constexpr dim3 blocks                  = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = 1;

        return ck_tile::launch_kernel(
            stream,
            ck_tile::make_kernel<blocks.x, kBlockPerCu>(
                k_{}, grids, blocks, 0,
                param.p_x, param.p_gamma, param.p_beta, param.p_y,
                param.p_mean, param.p_invStd, param.epsilon, param.M, param.N));
    };
};

template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    // TODO - Add SaveMeanInvStd instance
    constexpr bool kSaveMeanInvStd = false;

    return layernorm_dispatch<InOutDataType,
                              NRepeat,
                              NThread,
                              VectorAccessSize,
                              kPadN,
                              kSaveMeanInvStd,
                              kTwoPass>::Run(param, stream);
};
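The net effect of deleting this header: call sites that went through run_layernorm now reach the same kernel via layernorm2d_fwd_ plus the trait_ alias in layernorm2d_fwd_api.cpp. A fragment showing the two spellings of the same 512-wide fp16 instance, assuming a and s are in scope as they are in layernorm2d_fwd_api.cpp (note the swapped (a, s) vs. (s, a) argument order):

    #if 0 // before this commit, via the deleted layernorm_dispatch.hpp
    float r = run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(a, s);
    #else // after this commit, via layernorm2d_fwd.hpp and the instance files
    float r = layernorm2d_fwd_<trait_<ck_tile::fp16_t, 2, 64, 4, false>>(s, a);
    #endif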