Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
63214d01
Commit
63214d01
authored
Oct 12, 2024
by
letaoqin
Browse files
port layernorm
parent
29d384d0
Changes
14
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
908 additions
and
251 deletions
+908
-251
example/ck_tile/02_layernorm2d/CMakeLists.txt
example/ck_tile/02_layernorm2d/CMakeLists.txt
+18
-2
example/ck_tile/02_layernorm2d/README.md
example/ck_tile/02_layernorm2d/README.md
+2
-3
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
+145
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp
.../02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp
+28
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp
...layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp
+28
-0
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp
.../02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp
+34
-0
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
+39
-2
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
+151
-0
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
+103
-0
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
+78
-0
example/ck_tile/02_layernorm2d/perf_test.sh
example/ck_tile/02_layernorm2d/perf_test.sh
+32
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+235
-230
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
...ps/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
+12
-12
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
...e/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
+3
-2
No files found.
example/ck_tile/02_layernorm2d/CMakeLists.txt
View file @
63214d01
set
(
EXAMPLE_LAYERNORM2D_FWD
"tile_example_layernorm2d_fwd"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable
(
tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp
)
target_compile_options
(
tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD
)
\ No newline at end of file
message
(
"adding example
${
EXAMPLE_LAYERNORM2D_FWD
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
EXAMPLE_LAYERNORM2D_FWD
}
EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp
)
target_include_directories
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp
${
INSTANCE_SRCS
}
)
set
(
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
)
# list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
}
)
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property
(
GLOBAL PROPERTY RULE_MESSAGES OFF
)
example/ck_tile/02_layernorm2d/README.md
View file @
63214d01
...
...
@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable
`build/bin/tile_example_layernorm2d_fwd`
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
→
example/ck_tile/02_layernorm2d/
example_
layernorm2d_fwd.cpp
View file @
63214d01
...
...
@@ -2,61 +2,8 @@
#include "layernorm2d_fwd.hpp"
#include <cstring>
// Host API implementation
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
t
,
layernorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
using
thread_tile
=
ck_tile
::
sequence
<
4
,
4
>
;
using
warp_tile
=
ck_tile
::
sequence
<
8
,
128
>
;
using
block_tile
=
ck_tile
::
sequence
<
32
,
128
>
;
using
Shape
=
ck_tile
::
TileLayernorm2dShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
PipelineProblem
=
ck_tile
::
BlockLayernorm2dFwdProblem
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
MeanDataType
,
InvStdDataType
,
Shape
,
true
,
true
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
a
.
p_x
,
a
.
p_gamma
,
a
.
p_beta
,
a
.
p_y
,
a
.
p_mean
,
a
.
p_invStd
,
a
.
epsilon
,
a
.
M
,
a
.
N
);
const
dim3
grids
=
Kernel
::
GridSize
(
a
.
M
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
Shape
::
kMWarpPerBlock
*
Shape
::
kNWarpPerBlock
;
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
0
;
}
extern
float
layernorm2d_fwd_fp16
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
float
layernorm2d_fwd_fp32
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
auto
create_args
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -65,37 +12,37 @@ auto create_args(int argc, char* argv[])
.
insert
(
"n"
,
"4096"
,
"m dimension"
)
.
insert
(
"e"
,
"1e-5"
,
"epsilon"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
);
.
insert
(
"prec"
,
"fp32"
,
"precision"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
int
main
(
int
argc
,
char
*
argv
[])
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
float
epsilon
=
arg_parser
.
get_float
(
"e"
);
ck_tile
::
index_t
M
=
arg_parser
.
get_int
(
"m"
);
ck_tile
::
index_t
N
=
arg_parser
.
get_int
(
"n"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
using
TypeConfig
=
LayerNormTypeConfig
<
DataType
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
YDataType
=
typename
TypeConfig
::
YDataType
;
using
GammaDataType
=
typename
TypeConfig
::
GammaDataType
;
using
BetaDataType
=
typename
TypeConfig
::
BetaDataType
;
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
using
ComputeDataType
=
typename
TypeConfig
::
ComputeDataType
;
// host verify
ck_tile
::
HostTensor
<
XDataType
>
x_host
({
M
,
N
});
...
...
@@ -108,25 +55,15 @@ int main(int argc, char* argv[])
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
M
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
M
});
#ifdef SAVE_MEAN_INV_STD
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_dev
({
M
});
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_dev
({
M
});
#endif
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
5.
f
,
5.
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
5.
f
,
5.
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
5.
f
,
5.
f
}(
beta_host
);
ck_tile
::
FillUniformDistribution
<
XDataType
>
{
-
.5
f
,
.5
f
}(
x_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
ck_tile
::
DeviceMem
x_buf
(
x_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
#ifdef SAVE_MEAN_INV_STD
ck_tile
::
DeviceMem
mean_buf
(
mean_host_dev
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
invStd_buf
(
invStd_host_dev
.
get_element_space_size_in_bytes
());
#endif
x_buf
.
ToDevice
(
x_host
.
data
());
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
beta_buf
.
ToDevice
(
beta_host
.
data
());
...
...
@@ -137,26 +74,30 @@ int main(int argc, char* argv[])
gamma_buf
.
GetDeviceBuffer
(),
beta_buf
.
GetDeviceBuffer
(),
y_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
mean_buf
.
GetDeviceBuffer
(),
invStd_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
#endif
epsilon
,
M
,
N
};
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
});
float
ave_time
=
.0
;
if
constexpr
(
std
::
is_same
<
DataType
,
ck_tile
::
fp16_t
>::
value
)
{
ave_time
=
layernorm2d_fwd_fp16
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
});
}
else
if
constexpr
(
std
::
is_same
<
DataType
,
float
>::
value
)
{
ave_time
=
layernorm2d_fwd_fp32
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
});
}
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
N
+
sizeof
(
GammaDataType
)
*
N
+
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
YDataType
)
*
M
*
N
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"["
<<
data_type
<<
"]"
<<
" m:"
<<
M
<<
", n:"
<<
N
<<
", "
<<
ave_time
<<
"
m
s, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
<<
" m:"
<<
M
<<
", n:"
<<
N
<<
", "
<<
ave_time
*
1.E6
<<
"
n
s, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
flush
;
bool
pass
=
true
;
...
...
@@ -176,18 +117,29 @@ int main(int argc, char* argv[])
pass
=
ck_tile
::
check_err
(
y_host_dev
,
y_host_ref
);
#ifdef SAVE_MEAN_INV_STD
mean_buf
.
FromDevice
(
mean_host_dev
.
data
());
pass
&=
ck_tile
::
check_err
(
mean_host_dev
,
mean_host_ref
);
invStd_buf
.
FromDevice
(
invStd_host_dev
.
data
());
pass
&=
ck_tile
::
check_err
(
invStd_host_dev
,
invStd_host_ref
);
#endif
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
}
std
::
cout
<<
std
::
endl
<<
std
::
flush
;
return
!
pass
;
return
pass
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
if
(
data_type
==
"fp16"
)
{
return
run
<
ck_tile
::
half_t
>
(
arg_parser
)
?
0
:
-
2
;
}
if
(
data_type
==
"fp32"
)
{
return
run
<
float
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
}
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
View file @
63214d01
...
...
@@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits
std
::
string
data_type
;
};
template
<
typename
DataType
>
struct
LayerNormTypeConfig
;
template
<
>
struct
LayerNormTypeConfig
<
ck_tile
::
half_t
>
{
using
XDataType
=
ck_tile
::
half_t
;
using
YDataType
=
ck_tile
::
half_t
;
using
GammaDataType
=
ck_tile
::
half_t
;
using
BetaDataType
=
ck_tile
::
half_t
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
ck_tile
::
half_t
;
using
InvStdDataType
=
ck_tile
::
half_t
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
};
template
<
>
struct
LayerNormTypeConfig
<
float
>
{
using
XDataType
=
float
;
using
YDataType
=
float
;
using
GammaDataType
=
float
;
using
BetaDataType
=
float
;
#ifdef SAVE_MEAN_INV_STD
using
MeanDataType
=
float
;
using
InvStdDataType
=
float
;
#else
using
MeanDataType
=
ck_tile
::
null_type
;
using
InvStdDataType
=
ck_tile
::
null_type
;
#endif
using
ComputeDataType
=
float
;
};
struct
layernorm2d_fwd_args
{
const
void
*
p_x
;
const
void
*
p_gamma
;
const
void
*
p_beta
;
void
*
p_y
;
void
*
p_mean
;
void
*
p_invStd
;
//
void* p_mean;
//
void* p_invStd;
float
epsilon
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
float
layernorm2d_fwd_fp16
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
if(param.N % 8 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
}
else
{
return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
}
}
else if(param.N % 4 == 0)
#endif
if
(
param
.
N
%
4
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>
(
param
,
stream
);
}
}
else
if
(
param
.
N
%
2
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>
(
param
,
stream
);
}
}
else
{
throw
std
::
runtime_error
(
"Sequence length sizes not supported!"
);
}
};
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
extern
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
float
layernorm2d_fwd_fp32
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
if
(
param
.
N
%
4
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
float
,
1
,
32
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
32
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
float
,
1
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
float
,
2
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
2
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
float
,
4
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
4
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
,
true
>
(
param
,
stream
);
}
}
else
if
(
param
.
N
%
2
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
float
,
1
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
float
,
2
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
2
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
float
,
4
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
4
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
float
,
8
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
,
true
>
(
param
,
stream
);
}
}
else
{
throw
std
::
runtime_error
(
"Sequence length sizes not supported!"
);
}
};
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp
0 → 100644
View file @
63214d01
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"
template
<
typename
InOutDataType
,
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kPadN
,
bool
kTwoPass
>
struct
layernorm_dispatch
{
static
constexpr
ck_tile
::
index_t
MRepeat
=
1
;
static_assert
(
NThread
<=
64
,
"We only support intra-wave reduction"
);
static
constexpr
ck_tile
::
index_t
WaveNum
=
NThread
/
16
;
// clang-format off
using
thread_tile
=
ck_tile
::
sequence
<
MRepeat
,
NRepeat
,
VectorAccessSize
>
;
using
warp_tile
=
ck_tile
::
sequence
<
MRepeat
*
64
/
NThread
,
NRepeat
*
NThread
*
VectorAccessSize
>
;
using
block_tile
=
ck_tile
::
sequence
<
MRepeat
*
WaveNum
*
64
/
NThread
,
NRepeat
*
NThread
*
VectorAccessSize
>
;
// clang-format on
using
Shape
=
ck_tile
::
TileLayernorm2dShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
PipelineProblem
=
ck_tile
::
BlockLayernorm2dFwdProblem
<
typename
LayerNormTypeConfig
<
InOutDataType
>::
XDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
GammaDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
BetaDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
ComputeDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
YDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
MeanDataType
,
typename
LayerNormTypeConfig
<
InOutDataType
>::
InvStdDataType
,
Shape
,
kPadN
,
kTwoPass
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
static
float
Run
(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
using
k_
=
Kernel
;
const
dim3
grids
=
k_
::
GridSize
(
param
.
M
);
constexpr
dim3
blocks
=
k_
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
return
ck_tile
::
launch_kernel
(
stream
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
k_
{},
grids
,
blocks
,
0
,
param
.
p_x
,
param
.
p_gamma
,
param
.
p_beta
,
param
.
p_y
,
param
.
epsilon
,
param
.
M
,
param
.
N
));
};
};
template
<
typename
InOutDataType
,
ck_tile
::
index_t
NRepeat
,
ck_tile
::
index_t
NThread
,
ck_tile
::
index_t
VectorAccessSize
,
bool
kPadN
,
bool
kTwoPass
=
false
>
float
run_layernorm
(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
return
layernorm_dispatch
<
InOutDataType
,
NRepeat
,
NThread
,
VectorAccessSize
,
kPadN
,
kTwoPass
>::
Run
(
param
,
stream
);
};
example/ck_tile/02_layernorm2d/perf_test.sh
0 → 100644
View file @
63214d01
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec
=
fp32
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
80
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
128
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
144
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
168
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
184
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
256
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
288
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
344
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
376
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
448
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
512
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
924
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1024
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1078
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
1996
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
./bin/tile_example_layernorm2d_fwd
-m
=
700
-n
=
4080
-e
=
1e-12
-v
=
1
-prec
=
fp16
-repeat
=
1000
\ No newline at end of file
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
63214d01
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @
63214d01
...
...
@@ -15,8 +15,8 @@ template <typename XDataType_,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
,
bool
kPad
M
_
,
bool
k
PadN
_
>
bool
kPad
N
_
,
bool
k
TwoPass
_
>
struct
BlockLayernorm2dFwdProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
...
...
@@ -27,8 +27,8 @@ struct BlockLayernorm2dFwdProblem
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
View file @
63214d01
...
...
@@ -12,13 +12,14 @@ template <typename ThreadTile, // Sequence<...
struct
TileLayernorm2dShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerThread
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kNRepeat
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kNPerThread
=
ThreadTile
::
at
(
number
<
2
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kNThreadPerWarp
=
kNPerWarp
/
kNPerThread
;
static
constexpr
index_t
kNThreadPerWarp
=
kNPerWarp
/
kNPerThread
/
kNRepeat
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kNPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment