gaoqiong / composable_kernel_ROCM · Commits

Commit 63214d01, authored Oct 12, 2024 by letaoqin

port layernorm

Parent: 29d384d0
Showing 14 changed files with 908 additions and 251 deletions (+908 −251)
Changed files:
- example/ck_tile/02_layernorm2d/CMakeLists.txt (+18 −2)
- example/ck_tile/02_layernorm2d/README.md (+2 −3)
- example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp (+145 −0)
- example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp (+28 −0)
- example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp (+28 −0)
- example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp (+34 −0)
- example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp (+39 −2)
- example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp (+151 −0)
- example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp (+103 −0)
- example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp (+78 −0)
- example/ck_tile/02_layernorm2d/perf_test.sh (+32 −0)
- include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp (+235 −230)
- include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp (+12 −12)
- include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp (+3 −2)
example/ck_tile/02_layernorm2d/CMakeLists.txt

  set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
  # not using add_example_executable() to add this target, since we don't want this to have
  # to be included in "make all/install/check"
  message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
- add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
- target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD)
- \ No newline at end of file
+ file(GLOB INSTANCE_SRCS instances/*.cpp)
+ add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp)
+ target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS})
+
+ set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
+ # list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+ target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
+
+ # TODO: we have to turn off this global prop, otherwise the progress bar generated
+ # by cmake will print too many files, execvp: /bin/sh: Argument list too long
+ # however, this property may affect global
+ # TODO: consider codegen a makefile by us
+ set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
example/ck_tile/02_layernorm2d/README.md

@@ -6,8 +6,7 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
  ```
  # in the root of ck_tile
  mkdir build && cd build
- # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
- sh ../script/cmake-ck-dev.sh ../ <arch>
+ sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
  make tile_example_layernorm2d_fwd -j
  ```
  This will result in an executable `build/bin/tile_example_layernorm2d_fwd`

@@ -20,4 +19,4 @@ args:
  -e     epsilon (default:1e-5)
  -v     cpu validation or not (default:1)
  -prec  precision (default:fp16)
  ```
  \ No newline at end of file
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp → example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp

@@ -2,61 +2,8 @@
  #include "layernorm2d_fwd.hpp"
  #include <cstring>

- // Host API implementation
- float layernorm2d_fwd(layernorm2d_fwd_traits t,
-                       layernorm2d_fwd_args a,
-                       const ck_tile::stream_config& s)
- {
-     if(t.data_type.compare("fp16") == 0)
-     {
-         using XDataType     = ck_tile::half_t;
-         using YDataType     = ck_tile::half_t;
-         using GammaDataType = ck_tile::half_t;
-         using BetaDataType  = ck_tile::half_t;
- #ifdef SAVE_MEAN_INV_STD
-         using MeanDataType   = ck_tile::half_t;
-         using InvStdDataType = ck_tile::half_t;
- #else
-         using MeanDataType   = ck_tile::null_type;
-         using InvStdDataType = ck_tile::null_type;
- #endif
-         using ComputeDataType = float;
-
-         using thread_tile = ck_tile::sequence<4, 4>;
-         using warp_tile   = ck_tile::sequence<8, 128>;
-         using block_tile  = ck_tile::sequence<32, 128>;
-
-         using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;
-
-         using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem<XDataType,
-                                                                     GammaDataType,
-                                                                     BetaDataType,
-                                                                     ComputeDataType,
-                                                                     YDataType,
-                                                                     MeanDataType,
-                                                                     InvStdDataType,
-                                                                     Shape,
-                                                                     true,
-                                                                     true>;
-
-         using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
-
-         auto kargs = Kernel::MakeKargs(
-             a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N);
-
-         const dim3 grids                       = Kernel::GridSize(a.M);
-         constexpr dim3 blocks                  = Kernel::BlockSize();
-         constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock;
-
-         float ave_time = ck_tile::launch_kernel(
-             s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-         return ave_time;
-     }
-
-     return 0;
- }
+ extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
+ extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream);

  auto create_args(int argc, char* argv[])
  {

@@ -65,37 +12,37 @@ auto create_args(int argc, char* argv[])
          .insert("n", "4096", "m dimension")
          .insert("e", "1e-5", "epsilon")
          .insert("v", "1", "cpu validation or not")
-         .insert("prec", "fp16", "precision");
+         .insert("prec", "fp32", "precision")
+         .insert("warmup", "5", "cold iter")
+         .insert("repeat", "20", "hot iter");

      bool result = arg_parser.parse(argc, argv);
      return std::make_tuple(result, arg_parser);
  }

- int main(int argc, char* argv[])
+ template <typename DataType>
+ bool run(const ck_tile::ArgParser& arg_parser)
  {
-     auto [result, arg_parser] = create_args(argc, argv);
-     if(!result)
-         return -1;
-
      float epsilon         = arg_parser.get_float("e");
      ck_tile::index_t M    = arg_parser.get_int("m");
      ck_tile::index_t N    = arg_parser.get_int("n");
      std::string data_type = arg_parser.get_str("prec");
      int do_validation     = arg_parser.get_int("v");
+     int warmup            = arg_parser.get_int("warmup");
+     int repeat            = arg_parser.get_int("repeat");

-     using XDataType     = ck_tile::half_t;
-     using YDataType     = ck_tile::half_t;
-     using GammaDataType = ck_tile::half_t;
-     using BetaDataType  = ck_tile::half_t;
- #ifdef SAVE_MEAN_INV_STD
-     using MeanDataType   = ck_tile::half_t;
-     using InvStdDataType = ck_tile::half_t;
- #else
-     using MeanDataType   = ck_tile::null_type;
-     using InvStdDataType = ck_tile::null_type;
- #endif
-     using ComputeDataType = float;
+     using TypeConfig = LayerNormTypeConfig<DataType>;
+
+     using XDataType     = typename TypeConfig::XDataType;
+     using YDataType     = typename TypeConfig::YDataType;
+     using GammaDataType = typename TypeConfig::GammaDataType;
+     using BetaDataType  = typename TypeConfig::BetaDataType;
+
+     using MeanDataType   = ck_tile::null_type;
+     using InvStdDataType = ck_tile::null_type;
+
+     using ComputeDataType = typename TypeConfig::ComputeDataType;

      // host verify
      ck_tile::HostTensor<XDataType> x_host({M, N});

@@ -108,25 +55,15 @@ int main(int argc, char* argv[])
      ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
      ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
- #ifdef SAVE_MEAN_INV_STD
-     ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
-     ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
- #endif

-     ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host);
-     ck_tile::FillUniformDistribution<GammaDataType>{-5.f, 5.f}(gamma_host);
-     ck_tile::FillUniformDistribution<BetaDataType>{-5.f, 5.f}(beta_host);
+     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
+     ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
+     ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);

      ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
      ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
      ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
      ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
- #ifdef SAVE_MEAN_INV_STD
-     ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
-     ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
- #endif

      x_buf.ToDevice(x_host.data());
      gamma_buf.ToDevice(gamma_host.data());
      beta_buf.ToDevice(beta_host.data());

@@ -137,26 +74,30 @@ int main(int argc, char* argv[])
                                  gamma_buf.GetDeviceBuffer(),
                                  beta_buf.GetDeviceBuffer(),
                                  y_buf.GetDeviceBuffer(),
- #ifdef SAVE_MEAN_INV_STD
-                                 mean_buf.GetDeviceBuffer(),
-                                 invStd_buf.GetDeviceBuffer(),
- #else
-                                 nullptr,
-                                 nullptr,
- #endif
                                  epsilon,
                                  M,
                                  N};

-     float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true});
+     float ave_time = .0;
+     if constexpr(std::is_same<DataType, ck_tile::fp16_t>::value)
+     {
+         ave_time = layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
+     }
+     else if constexpr(std::is_same<DataType, float>::value)
+     {
+         ave_time = layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
+     }

      std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
                             sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;

      float gb_per_sec = num_byte / 1.E6 / ave_time;

      std::cout << "[" << data_type << "]"
-               << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s"
-               << std::flush;
+               << " m:" << M << ", n:" << N << ", " << ave_time * 1.E6 << " ns, " << gb_per_sec
+               << " GB/s" << std::flush;

      bool pass = true;

@@ -176,18 +117,29 @@ int main(int argc, char* argv[])
          pass = ck_tile::check_err(y_host_dev, y_host_ref);
- #ifdef SAVE_MEAN_INV_STD
-         mean_buf.FromDevice(mean_host_dev.data());
-         pass &= ck_tile::check_err(mean_host_dev, mean_host_ref);
-         invStd_buf.FromDevice(invStd_host_dev.data());
-         pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref);
- #endif

          std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
      }

      std::cout << std::endl << std::flush;

-     return !pass;
+     return pass;
  }
+
+ int main(int argc, char* argv[])
+ {
+     auto [result, arg_parser] = create_args(argc, argv);
+     if(!result)
+         return -1;
+
+     const std::string data_type = arg_parser.get_str("prec");
+     if(data_type == "fp16")
+     {
+         return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
+     }
+     if(data_type == "fp32")
+     {
+         return run<float>(arg_parser) ? 0 : -2;
+     }
+     return -3;
+ }
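The run<DataType>() path above times the kernel and compares it against a host reference. For orientation, the row-wise layer normalization being validated reduces each of the M rows to its mean and variance and then applies y = (x - mean) * inv_std * gamma + beta, the same expression the device kernel uses. A minimal standalone host sketch of that arithmetic (plain C++, independent of ck_tile's own host-reference utilities; all names here are illustrative) is:

    #include <cmath>
    #include <vector>

    // Minimal host reference for 2D layernorm: each of the M rows of x (length N)
    // is normalized with its own mean/variance, then scaled by gamma and shifted by beta.
    void layernorm2d_host_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                              const std::vector<float>& beta, std::vector<float>& y,
                              int M, int N, float epsilon)
    {
        for(int m = 0; m < M; ++m)
        {
            double mean = 0.0, var = 0.0;
            for(int n = 0; n < N; ++n)
                mean += x[m * N + n];
            mean /= N;
            for(int n = 0; n < N; ++n)
            {
                const double d = x[m * N + n] - mean;
                var += d * d;
            }
            var /= N;
            const double inv_std = 1.0 / std::sqrt(var + epsilon);
            for(int n = 0; n < N; ++n)
                y[m * N + n] =
                    static_cast<float>((x[m * N + n] - mean) * inv_std * gamma[n] + beta[n]);
        }
    }

The example itself relies on ck_tile host tensors and check_err for this comparison; the sketch only makes the arithmetic explicit.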
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
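This instance file (and its _pad and fp32 siblings below) exists so that each kernel configuration is compiled exactly once: the explicit instantiations here emit the symbols, while the dispatch translation units declare them with extern template and only pick among them at run time. A generic, single-file sketch of that split (hypothetical names, not from this repo) looks like:

    // --- shared header: the template definition (here layernorm_dispatch.hpp) ---
    template <typename T, int Vec>
    float run_kernel(const T* /*data*/, int n)
    {
        // stand-in body; the real one configures and launches a GPU kernel
        return static_cast<float>(n / Vec);
    }

    // --- dispatch TU: suppress implicit instantiation, just reference the symbols ---
    extern template float run_kernel<float, 4>(const float*, int);
    extern template float run_kernel<float, 2>(const float*, int);

    float run_dispatch(const float* data, int n)
    {
        return (n % 4 == 0) ? run_kernel<float, 4>(data, n) : run_kernel<float, 2>(data, n);
    }

    // --- instance TU: emit each chosen configuration exactly once ---
    template float run_kernel<float, 4>(const float*, int);
    template float run_kernel<float, 2>(const float*, int);

    int main()
    {
        float dummy[4] = {};
        return static_cast<int>(run_dispatch(dummy, 4));
    }

Splitting the instantiations this way keeps the heavy kernel compilation out of the dispatch files and shortens rebuilds when only the selection logic changes.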
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp16_kernel_pad.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
// template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
example/ck_tile/02_layernorm2d/instances/layernorm2d_fwd_fp32_kernel.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
template float run_layernorm<float, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp

@@ -13,14 +13,51 @@ struct layernorm2d_fwd_traits
      std::string data_type;
  };

+ template <typename DataType>
+ struct LayerNormTypeConfig;
+
+ template <>
+ struct LayerNormTypeConfig<ck_tile::half_t>
+ {
+     using XDataType     = ck_tile::half_t;
+     using YDataType     = ck_tile::half_t;
+     using GammaDataType = ck_tile::half_t;
+     using BetaDataType  = ck_tile::half_t;
+ #ifdef SAVE_MEAN_INV_STD
+     using MeanDataType   = ck_tile::half_t;
+     using InvStdDataType = ck_tile::half_t;
+ #else
+     using MeanDataType   = ck_tile::null_type;
+     using InvStdDataType = ck_tile::null_type;
+ #endif
+     using ComputeDataType = float;
+ };
+
+ template <>
+ struct LayerNormTypeConfig<float>
+ {
+     using XDataType     = float;
+     using YDataType     = float;
+     using GammaDataType = float;
+     using BetaDataType  = float;
+ #ifdef SAVE_MEAN_INV_STD
+     using MeanDataType   = float;
+     using InvStdDataType = float;
+ #else
+     using MeanDataType   = ck_tile::null_type;
+     using InvStdDataType = ck_tile::null_type;
+ #endif
+     using ComputeDataType = float;
+ };
+
  struct layernorm2d_fwd_args
  {
      const void* p_x;
      const void* p_gamma;
      const void* p_beta;
      void* p_y;
-     void* p_mean;
-     void* p_invStd;
+     // void* p_mean;
+     // void* p_invStd;
      float epsilon;
      ck_tile::index_t M;
      ck_tile::index_t N;
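With the argument struct reduced to x/gamma/beta/y plus epsilon and the problem size, a host-side call into the fp16 entry point added by this commit would look roughly like the sketch below. The field order and the stream_config values mirror what example_layernorm2d_fwd.cpp does; the wrapper function and buffer names are illustrative assumptions, not part of the commit.

    #include "layernorm2d_fwd.hpp"

    // Defined in layernorm2d_fwd_fp16.cpp (added by this commit).
    extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);

    // Hypothetical wrapper: d_x/d_gamma/d_beta/d_y are assumed to be valid device buffers
    // of shape [M, N], [N], [N] and [M, N] respectively.
    float launch_fp16_layernorm(void* d_x, void* d_gamma, void* d_beta, void* d_y,
                                ck_tile::index_t M, ck_tile::index_t N)
    {
        layernorm2d_fwd_args args{d_x, d_gamma, d_beta, d_y, 1e-5f, M, N};

        // {nullptr, true, 0, warmup, repeat} follows the example: time the kernel with
        // 5 cold and 20 hot iterations and return the average time.
        return layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, 5, 20});
    }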
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);

// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on

float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
    if(param.N % 8 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
        }
        else
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
        }
    }
    else if(param.N % 4 == 0)
#endif
    if(param.N % 4 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(param, stream);
        }
        else
        {
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(param, stream);
        }
    }
    else if(param.N % 2 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(param, stream)
                                  : run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(param, stream)
                                   : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(param, stream);
        }
        else
        {
            return param.N % 2048 == 0
                       ? run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(param, stream)
                       : run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(param, stream);
        }
    }
    else
    {
        throw std::runtime_error("Sequence length sizes not supported!");
    }
};
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
extern template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on

float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    if(param.N % 4 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<float, 1, 32, 4, false>(param, stream)
                                  : run_layernorm<float, 1, 32, 4, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<float, 1, 64, 4, false>(param, stream)
                                  : run_layernorm<float, 1, 64, 4, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<float, 2, 64, 4, false>(param, stream)
                                  : run_layernorm<float, 2, 64, 4, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<float, 4, 64, 4, false>(param, stream)
                                   : run_layernorm<float, 4, 64, 4, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<float, 8, 64, 4, false>(param, stream)
                                   : run_layernorm<float, 8, 64, 4, true>(param, stream);
        }
        else
        {
            return param.N % 2048 == 0 ? run_layernorm<float, 8, 64, 4, false, true>(param, stream)
                                       : run_layernorm<float, 8, 64, 4, true, true>(param, stream);
        }
    }
    else if(param.N % 2 == 0)
    {
        if(param.N <= 128)
        {
            return param.N == 128 ? run_layernorm<float, 1, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 1, 64, 2, true>(param, stream);
        }
        else if(param.N <= 256)
        {
            return param.N == 256 ? run_layernorm<float, 2, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 2, 64, 2, true>(param, stream);
        }
        else if(param.N <= 512)
        {
            return param.N == 512 ? run_layernorm<float, 4, 64, 2, false>(param, stream)
                                  : run_layernorm<float, 4, 64, 2, true>(param, stream);
        }
        else if(param.N <= 1024)
        {
            return param.N == 1024 ? run_layernorm<float, 8, 64, 2, false>(param, stream)
                                   : run_layernorm<float, 8, 64, 2, true>(param, stream);
        }
        else if(param.N <= 2048)
        {
            return param.N == 2048 ? run_layernorm<float, 16, 64, 2, false>(param, stream)
                                   : run_layernorm<float, 16, 64, 2, true>(param, stream);
        }
        else
        {
            return param.N % 2048 == 0 ? run_layernorm<float, 16, 64, 2, false, true>(param, stream)
                                       : run_layernorm<float, 16, 64, 2, true, true>(param, stream);
        }
    }
    else
    {
        throw std::runtime_error("Sequence length sizes not supported!");
    }
};
example/ck_tile/02_layernorm2d/layernorm_dispatch.hpp (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include "layernorm2d_fwd.hpp"

template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kTwoPass>
struct layernorm_dispatch
{
    static constexpr ck_tile::index_t MRepeat = 1;
    static_assert(NThread <= 64, "We only support intra-wave reduction");
    static constexpr ck_tile::index_t WaveNum = NThread / 16;

    // clang-format off
    using thread_tile = ck_tile::sequence<MRepeat, NRepeat, VectorAccessSize>;
    using warp_tile   = ck_tile::sequence<MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
    using block_tile  = ck_tile::sequence<MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize>;
    // clang-format on

    using Shape = ck_tile::TileLayernorm2dShape<thread_tile, warp_tile, block_tile>;

    using PipelineProblem =
        ck_tile::BlockLayernorm2dFwdProblem<typename LayerNormTypeConfig<InOutDataType>::XDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::GammaDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::BetaDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::ComputeDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::YDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::MeanDataType,
                                            typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
                                            Shape,
                                            kPadN,
                                            kTwoPass>;

    using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

    static float Run(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
    {
        using k_ = Kernel;

        const dim3 grids                       = k_::GridSize(param.M);
        constexpr dim3 blocks                  = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = 1;

        return ck_tile::launch_kernel(stream,
                                      ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{},
                                                                                  grids,
                                                                                  blocks,
                                                                                  0,
                                                                                  param.p_x,
                                                                                  param.p_gamma,
                                                                                  param.p_beta,
                                                                                  param.p_y,
                                                                                  param.epsilon,
                                                                                  param.M,
                                                                                  param.N));
    };
};

template <typename InOutDataType,
          ck_tile::index_t NRepeat,
          ck_tile::index_t NThread,
          ck_tile::index_t VectorAccessSize,
          bool kPadN,
          bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    return layernorm_dispatch<InOutDataType, NRepeat, NThread, VectorAccessSize, kPadN, kTwoPass>::
        Run(param, stream);
};
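Every run_layernorm<...> instantiation therefore maps (NRepeat, NThread, VectorAccessSize) to concrete thread/warp/block tiles through the arithmetic above. A small constexpr sketch of that mapping, using plain integers instead of ck_tile::sequence (illustrative only), shows why, for example, the <fp16, 1, 32, 4, ...> instance covers rows of up to 128 elements:

    #include <cstdio>

    // Reproduces the tile arithmetic from layernorm_dispatch (MRepeat fixed to 1, wave size 64):
    //   thread_tile = {MRepeat, NRepeat, VectorAccessSize}
    //   warp_tile   = {MRepeat * 64 / NThread, NRepeat * NThread * VectorAccessSize}
    //   block_tile  = {MRepeat * WaveNum * 64 / NThread, NRepeat * NThread * VectorAccessSize}
    struct TileShape
    {
        int warp_m, warp_n, block_m, block_n;
    };

    constexpr TileShape make_tile_shape(int NRepeat, int NThread, int VectorAccessSize)
    {
        constexpr int MRepeat = 1;
        const int WaveNum     = NThread / 16; // as in the dispatch header
        return {MRepeat * 64 / NThread,
                NRepeat * NThread * VectorAccessSize,
                MRepeat * WaveNum * 64 / NThread,
                NRepeat * NThread * VectorAccessSize};
    }

    int main()
    {
        // The <fp16, 1, 32, 4, ...> instance: a 4x128 block tile, i.e. 128 columns per row tile.
        constexpr auto s = make_tile_shape(1, 32, 4);
        std::printf("warp %dx%d, block %dx%d\n", s.warp_m, s.warp_n, s.block_m, s.block_n);
        return 0;
    }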
example/ck_tile/02_layernorm2d/perf_test.sh (new file)

./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp32 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
./bin/tile_example_layernorm2d_fwd -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
63214d01
...
@@ -31,14 +31,10 @@ struct Layernorm2dFwd
...
@@ -31,14 +31,10 @@ struct Layernorm2dFwd
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
struct
Kargs
struct
Kargs
{
{
...
@@ -47,8 +43,8 @@ struct Layernorm2dFwd
...
@@ -47,8 +43,8 @@ struct Layernorm2dFwd
const
void
*
p_beta
;
const
void
*
p_beta
;
void
*
p_y
;
void
*
p_y
;
void
*
p_mean
;
//
void* p_mean;
void
*
p_invStd
;
//
void* p_invStd;
float
epsilon
;
float
epsilon
;
...
@@ -69,7 +65,10 @@ struct Layernorm2dFwd
...
@@ -69,7 +65,10 @@ struct Layernorm2dFwd
return
Kargs
{
p_x
,
p_gamma
,
p_beta
,
p_y
,
p_mean
,
p_invStd
,
epsilon
,
M
,
N
};
return
Kargs
{
p_x
,
p_gamma
,
p_beta
,
p_y
,
p_mean
,
p_invStd
,
epsilon
,
M
,
N
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
)
{
return
M
/
kMPerBlock
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
)
{
return
(
M
+
kMPerBlock
-
1
)
/
kMPerBlock
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
...
@@ -81,11 +80,11 @@ struct Layernorm2dFwd
...
@@ -81,11 +80,11 @@ struct Layernorm2dFwd
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
,
S
::
kMPerThread
>
,
tuple
<
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
,
S
::
kMPerThread
>
,
sequence
<
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
sequence
<
S
::
kNRepeat
,
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
2
,
2
>>
{});
sequence
<
2
,
0
,
3
>>
{});
}
}
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
...
@@ -95,32 +94,26 @@ struct Layernorm2dFwd
...
@@ -95,32 +94,26 @@ struct Layernorm2dFwd
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
>
,
sequence
<
S
::
kMWarpPerBlock
,
S
::
kMThreadPerWarp
>
,
tuple
<
sequence
<
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
tuple
<
sequence
<
S
::
kNRepeat
,
S
::
kNWarpPerBlock
,
S
::
kNThreadPerWarp
,
S
::
kNPerThread
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
>
,
sequence
<
1
,
1
>
,
sequence
<
2
>>
{});
sequence
<
0
,
3
>>
{});
}
}
CK_TILE_DEVICE
static
int
GetWelfordMaxCount
(
int
N
)
template
<
typename
Dstr
>
CK_TILE_DEVICE
static
constexpr
auto
GetNPerThread
(
Dstr
)
{
{
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
constexpr
auto
nDstrSpan
=
Dstr
::
get_distributed_spans
().
template
at
<
1
>()
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
using
Lengths
=
decltype
(
nDstrSpan
.
impl_
);
int
max_count
=
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNPerThread
*
(
N
/
kNPerBlock
));
int
n_per_block_tail_loop
=
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
if
(
n_per_block_tail_loop
>
0
)
ck_tile
::
index_t
ret
=
1
;
{
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
;
ck_tile
::
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
[
&
](
auto
idx
)
{
ret
*=
Lengths
::
template
at
(
idx
);
});
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
}
return
max_coun
t
;
return
re
t
;
}
}
template
<
typename
DistributedTensor
>
template
<
typename
DistributedTensor
>
...
@@ -141,127 +134,70 @@ struct Layernorm2dFwd
...
@@ -141,127 +134,70 @@ struct Layernorm2dFwd
return
out_dstr_tensor
;
return
out_dstr_tensor
;
}
}
template
<
typename
XBlockWindow
,
CK_TILE_HOST_DEVICE
static
constexpr
auto
typename
GammaBlockWindow
,
GetLastloopLayerNormIntraLaneReduceCount
(
index_t
NLength
)
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
{
// TODO - Optimize tail loop to reduce move_tile_window()
using
S
=
typename
Problem
::
BlockShape
;
index_t
num_n_tile_iteration
=
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
auto
LastloopN
=
NLength
%
kNPerBlock
==
0
?
kNPerBlock
:
NLength
%
kNPerBlock
;
constexpr
auto
NThread
=
S
::
kNWarpPerBlock
*
S
::
kNThreadPerWarp
;
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
auto
iNLane
=
get_thread_local_1d_id
()
%
NThread
;
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
auto
iN0
=
LastloopN
/
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
);
auto
iN1
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
/
S
::
kNPerThread
;
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
N2
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
%
S
::
kNPerThread
;
auto
mean_compute_block_tensor
=
auto
iN3
=
iNLane
<
iN1
?
S
::
kNPerThread
:
iNLane
==
iN1
?
N2
:
0
;
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
auto
var_compute_block_tensor
=
return
iN0
*
S
::
kNPerThread
+
iN3
;
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
}
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
move_tile_window
(
x_block_window
,
{
0
,
kNPerBlock
});
}
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
mean_compute_block_tensor
,
var_compute_block_tensor
,
thread_welford
.
cur_count_
);
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
    move_tile_window(gamma_block_window, {stride_to_right_most_window});
    move_tile_window(beta_block_window, {stride_to_right_most_window});
    move_tile_window(y_block_window, {0, stride_to_right_most_window});

    // Normalization
    for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
    {
        const auto x_block_tensor     = load_tile(x_block_window);
        const auto gamma_block_tensor = load_tile(gamma_block_window);
        const auto beta_block_tensor  = load_tile(beta_block_window);

        constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();

        auto y_block_tensor =
            make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());

        sweep_tile_span(x_spans[I1], [&](auto idx1) {
            constexpr auto j_idx = make_tuple(idx1);

            const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
            const auto beta  = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);

            sweep_tile_span(x_spans[I0], [&](auto idx0) {
                constexpr auto i_idx   = make_tuple(idx0);
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                const auto mean    = mean_compute_block_tensor[i_idx];
                const auto inv_std = inv_std_compute_block_tensor[i_idx];
                const auto x       = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);

                auto y = (x - mean) * inv_std * gamma + beta;

                y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
            });
        });

        store_tile(y_block_window, y_block_tensor);

        move_tile_window(x_block_window, {0, -kNPerBlock});
        move_tile_window(gamma_block_window, {-kNPerBlock});
        move_tile_window(beta_block_window, {-kNPerBlock});
        move_tile_window(y_block_window, {0, -kNPerBlock});
    }
}

template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(const XDataType* p_x,
                                                            const GammaDataType* p_gamma,
                                                            const BetaDataType* p_beta,
                                                            YDataType* p_y,
                                                            const ComputeDataType epsilon,
                                                            ck_tile::index_t M,
                                                            ck_tile::index_t N) const
{
    using S = typename Problem::BlockShape;

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    const auto x_m_n = [&]() {
        const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(x_dram_naive,
                               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                               sequence<false, kPadN>{});
    }();

    const auto gamma_n = [&]() {
        const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(
            gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
    }();

    const auto beta_n = [&]() {
        const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(
            gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
    }();

    const auto iM = get_block_id() * kMPerBlock;

    constexpr auto xDstr = MakeXBlockTileDistribution();

    auto x_block_window = make_tile_window(
        x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);

    auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
    ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};

template <typename XBlockWindow,
          typename GammaBlockWindow,
          typename BetaBlockWindow,
          typename YBlockWindow,
          typename MeanBlockWindow,
          typename InvStdBlockWindow,
          bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
                                                            GammaBlockWindow& gamma_block_window,
                                                            BetaBlockWindow& beta_block_window,
                                                            YBlockWindow& y_block_window,
                                                            MeanBlockWindow& mean_block_window,
                                                            InvStdBlockWindow& inv_std_block_window,
                                                            ComputeDataType epsilon,
                                                            ck_tile::index_t N) const
{
    int welford_max_count = GetWelfordMaxCount(N);
    ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};

    using XTensorType = decltype(load_tile(x_block_window));
    auto mean_compute_block_tensor =
...
@@ -274,21 +210,37 @@ struct Layernorm2dFwd
    const auto x_block_tensor = load_tile(x_block_window);

    thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);

    constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
    constexpr auto betaDstr  = gammaDstr;

    auto gamma_block_window =
        make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
    auto beta_block_window =
        make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);

    const auto gamma_block_tensor = load_tile(gamma_block_window);
    const auto beta_block_tensor  = load_tile(beta_block_window);

    // TODO: support cross warp Welford
    WarpMergeWelford<ComputeDataType, true>{}(
        mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);

    auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

    if constexpr(kSaveMean)
        store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));

    if constexpr(kSaveInvStd)
        store_tile(inv_std_block_window, cast_tile<InvStdDataType>(inv_std_compute_block_tensor));

    // TODO: Extract normalize pipeline
    const auto y_m_n = [&]() {
        const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(y_dram_naive,
                               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                               sequence<false, kPadN>{});
    }();

    auto y_block_window = make_tile_window(
        y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});

    // normalize
    constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
...
@@ -317,43 +269,42 @@ struct Layernorm2dFwd
    store_tile(y_block_window, y_block_tensor);
}
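The inner sweep above applies the standard layernorm transform, y = (x - mean) * inv_std * gamma + beta, element by element over the distributed tile. For reference, a scalar row-wise version of the same math (illustrative only, independent of ck_tile):

// Scalar reference for one row of layernorm2d forward; illustrative only.
#include <cmath>
#include <vector>

void layernorm_row_ref(const std::vector<float>& x,
                       const std::vector<float>& gamma,
                       const std::vector<float>& beta,
                       std::vector<float>& y,
                       float epsilon)
{
    const std::size_t n = x.size();

    float mean = 0.f;
    for(float v : x)
        mean += v;
    mean /= static_cast<float>(n);

    float var = 0.f;
    for(float v : x)
        var += (v - mean) * (v - mean);
    var /= static_cast<float>(n);

    const float inv_std = 1.f / std::sqrt(var + epsilon);

    y.resize(n);
    for(std::size_t j = 0; j < n; ++j)
        y[j] = (x[j] - mean) * inv_std * gamma[j] + beta[j];
}

A host-side reference of this kind is the sort of check the example executables typically validate against; the acceptable tolerance depends on the input and output data types.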
CK_TILE_DEVICE void operator()(Kargs kargs) const

template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
                                                            const GammaDataType* p_gamma,
                                                            const BetaDataType* p_beta,
                                                            YDataType* p_y,
                                                            const ComputeDataType epsilon,
                                                            ck_tile::index_t M,
                                                            ck_tile::index_t N) const
{
    using S = typename Problem::BlockShape;

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    const auto x_m_n = [&]() {
        const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(x_dram_naive,
                               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                               sequence<false, true>{});
    }();

    const auto gamma_n = [&]() {
        const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(
            gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
    }();

    const auto beta_n = [&]() {
        const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(
            gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
    }();

    const auto iM = get_block_id() * kMPerBlock;
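The hunk that follows derives the number of N-direction tile iterations with a ceiling division and sizes the last, possibly partial, tile of a row separately. A standalone sketch of that arithmetic with made-up values (not tied to any particular instance in this commit):

// Standalone sketch of the tile-count arithmetic used by the two-pass path.
#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int kNPerBlock = 128; // illustrative block width
    constexpr int N          = 300; // illustrative row length

    // Number of N-direction tile iterations: ceil(N / kNPerBlock).
    constexpr int num_n_tile_iteration = ceil_div(N, kNPerBlock);
    static_assert(num_n_tile_iteration == 3, "300 columns need three 128-wide tiles");

    // Elements handled by the last (possibly partial) tile of a row.
    constexpr int last_tile_n = (N % kNPerBlock == 0) ? kNPerBlock : N % kNPerBlock;
    static_assert(last_tile_n == 44, "300 = 2 * 128 + 44");

    assert(ceil_div(256, kNPerBlock) == 2); // exact multiple: no partial tile
    return 0;
}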
...
@@ -363,17 +314,52 @@ struct Layernorm2dFwd
    auto x_block_window = make_tile_window(
        x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);

    index_t num_n_tile_iteration =
        __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);

    auto intra_thread_count      = S::kNRepeat * S::kNPerThread * (num_n_tile_iteration - 1);
    auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);

    ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count};
    ThreadWelford<ComputeDataType, XDataType> thread_welford_last{intra_thread_count_last};

    using XTensorType = decltype(load_tile(x_block_window));
    auto mean_compute_block_tensor =
        thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
    auto var_compute_block_tensor =
        thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();

    clear_tile(mean_compute_block_tensor);
    clear_tile(var_compute_block_tensor);

    for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration - 1; ++iN)
    {
        const auto x_block_tensor = load_tile(x_block_window);

        thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
        move_tile_window(x_block_window, {0, kNPerBlock});
    }

    const auto x_block_tensor_ = load_tile(x_block_window);

    thread_welford_last.cur_count_ += intra_thread_count;
    thread_welford_last.max_count_ += intra_thread_count;
    thread_welford_last(x_block_tensor_, mean_compute_block_tensor, var_compute_block_tensor);
    thread_welford.cur_count_ += intra_thread_count_last;

    // TODO: support cross warp Welford
    WarpMergeWelford<ComputeDataType, true>{}(
        mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);

    auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

    // TODO: Extract normalize pipeline
    const auto y_m_n = [&]() {
        const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});

        return pad_tensor_view(y_dram_naive,
                               make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                               sequence<false, true>{});
    }();

    auto y_block_window = make_tile_window(
...
@@ -385,67 +371,86 @@ struct Layernorm2dFwd
    auto gamma_block_window =
        make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);

    auto beta_block_window =
        make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);

    auto mean_block_window = [&]() {
        if constexpr(kSaveMean)
        {
            const auto mean_m = [&]() {
                const auto mean_dram_naive =
                    make_naive_tensor_view_packed<address_space_enum::global>(
                        static_cast<MeanDataType*>(kargs.p_mean),
                        make_tuple(kargs.M),
                        number<1>{});

                return pad_tensor_view(
                    mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
            }();

            return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
        }
        else
            return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
    }();

    auto inv_std_block_window = [&]() {
        if constexpr(kSaveInvStd)
        {
            const auto inv_std_m = [&]() {
                const auto inv_std_dram_naive =
                    make_naive_tensor_view_packed<address_space_enum::global>(
                        static_cast<InvStdDataType*>(kargs.p_invStd),
                        make_tuple(kargs.M),
                        number<1>{});

                return pad_tensor_view(
                    inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
            }();

            return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
        }
        else
            return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
    }();

    // reverse read x to reuse cache
    ck_tile::index_t stride_to_right_most_window =
        N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
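The "reverse read x to reuse cache" trick works as follows: the statistics pass ends at the right-most tile of a row, so the normalization pass walks the row from right to left, and stride_to_right_most_window is the column offset of the last window's start. A small numeric check of that offset (illustrative values only):

// Illustrative check of the reverse-read offset computed above.
#include <cstdio>

int stride_to_right_most_window(int N, int kNPerBlock)
{
    return N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
}

int main()
{
    // Exact multiple: the last window starts one full tile before the end.
    std::printf("%d\n", stride_to_right_most_window(512, 128)); // prints 384
    // Ragged row: the last (padded) window starts at the last full-tile boundary.
    std::printf("%d\n", stride_to_right_most_window(300, 128)); // prints 256
    return 0;
}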
    if(kargs.N <= kNPerBlock)
        OnePassLayernorm2dFwd(x_block_window,
                              gamma_block_window,
                              beta_block_window,
                              y_block_window,
                              mean_block_window,
                              inv_std_block_window,
                              static_cast<const ComputeDataType>(kargs.epsilon),
                              kargs.N);
    else
        TwoPassLayernorm2dFwd(x_block_window,
                              gamma_block_window,
                              beta_block_window,
                              y_block_window,
                              mean_block_window,
                              inv_std_block_window,
                              static_cast<const ComputeDataType>(kargs.epsilon),
                              kargs.N);

    move_tile_window(gamma_block_window, {stride_to_right_most_window});
    move_tile_window(beta_block_window, {stride_to_right_most_window});
    move_tile_window(y_block_window, {0, stride_to_right_most_window});

    // Normalization
    for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
    {
        const auto x_block_tensor     = load_tile(x_block_window);
        const auto gamma_block_tensor = load_tile(gamma_block_window);
        const auto beta_block_tensor  = load_tile(beta_block_window);

        constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();

        auto y_block_tensor =
            make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());

        sweep_tile_span(x_spans[I1], [&](auto idx1) {
            constexpr auto j_idx = make_tuple(idx1);

            const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
            const auto beta  = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);

            sweep_tile_span(x_spans[I0], [&](auto idx0) {
                constexpr auto i_idx   = make_tuple(idx0);
                constexpr auto i_j_idx = make_tuple(idx0, idx1);

                const auto mean    = mean_compute_block_tensor[i_idx];
                const auto inv_std = inv_std_compute_block_tensor[i_idx];
                const auto x       = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);

                auto y = (x - mean) * inv_std * gamma + beta;

                y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
            });
        });

        store_tile(y_block_window, y_block_tensor);

        move_tile_window(x_block_window, {0, -kNPerBlock});
        move_tile_window(gamma_block_window, {-kNPerBlock});
        move_tile_window(beta_block_window, {-kNPerBlock});
        move_tile_window(y_block_window, {0, -kNPerBlock});
    }
}

CK_TILE_DEVICE void operator()(const void* p_x,
                               const void* p_gamma,
                               const void* p_beta,
                               void* p_y,
                               const ComputeDataType epsilon,
                               ck_tile::index_t M,
                               ck_tile::index_t N) const
{
    if constexpr(kTwoPass)
    {
        TwoPassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
                              static_cast<const GammaDataType*>(p_gamma),
                              static_cast<const BetaDataType*>(p_beta),
                              static_cast<YDataType*>(p_y),
                              static_cast<const ComputeDataType>(epsilon),
                              M,
                              N);
    }
    else
        OnePassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
                              static_cast<const GammaDataType*>(p_gamma),
                              static_cast<const BetaDataType*>(p_beta),
                              static_cast<YDataType*>(p_y),
                              static_cast<const ComputeDataType>(epsilon),
                              M,
                              N);
}
};
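The new entry point selects the one-pass or two-pass path at compile time via if constexpr(kTwoPass), whereas the replaced code branched at run time on kargs.N <= kNPerBlock. A hedged host-side sketch of how a caller might decide which instance to launch; the enum and function are placeholders for illustration, not the example's actual dispatch code (which is not shown in this excerpt):

// Hypothetical host-side selection sketch: pick a one-pass kernel when a row
// fits in a single block tile, otherwise a two-pass kernel.
#include <cstdint>

enum class LayernormPath { OnePass, TwoPass };

inline LayernormPath choose_layernorm_path(std::int64_t n, std::int64_t n_per_block)
{
    // One pass is enough when every row fits into one block-wide tile, so the
    // statistics and the normalization can share a single read of x.
    return (n <= n_per_block) ? LayernormPath::OnePass : LayernormPath::TwoPass;
}

int main()
{
    const auto p0 = choose_layernorm_path(768, 1024);  // OnePass
    const auto p1 = choose_layernorm_path(4096, 1024); // TwoPass
    return (p0 == LayernormPath::OnePass && p1 == LayernormPath::TwoPass) ? 0 : 1;
}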
...
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @ 63214d01
...
@@ -15,20 +15,20 @@ template <typename XDataType_,
          typename MeanDataType_,
          typename InvStdDataType_,
          typename BlockShape_,
-         bool kPadM_,
-         bool kPadN_>
+         bool kPadN_,
+         bool kTwoPass_>
struct BlockLayernorm2dFwdProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using GammaDataType   = remove_cvref_t<GammaDataType_>;
    using BetaDataType    = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using MeanDataType    = remove_cvref_t<MeanDataType_>;
    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;

-   static constexpr bool kPadM    = kPadM_;
-   static constexpr bool kPadN    = kPadN_;
+   static constexpr bool kPadN    = kPadN_;
+   static constexpr bool kTwoPass = kTwoPass_;
};
} // namespace ck_tile
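The problem descriptor now carries a kTwoPass flag (in place of kPadM), so the pass choice is baked into the instantiated kernel type rather than decided per launch. A generic sketch of how such a compile-time trait typically drives if constexpr dispatch; ToyProblem and ToyKernel are placeholders, not the ck_tile definitions:

// Generic illustration of a compile-time pass switch carried as a trait,
// mirroring the kTwoPass flag added to BlockLayernorm2dFwdProblem above.
#include <cstdio>

template <bool kPadN_, bool kTwoPass_>
struct ToyProblem
{
    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;
};

template <typename Problem>
struct ToyKernel
{
    void operator()() const
    {
        if constexpr(Problem::kTwoPass)
            std::puts("two-pass: statistics sweep, then a normalization sweep");
        else
            std::puts("one-pass: single sweep, row fits in one block tile");
    }
};

int main()
{
    ToyKernel<ToyProblem<true, false>>{}(); // one-pass instance
    ToyKernel<ToyProblem<true, true>>{}();  // two-pass instance
    return 0;
}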
include/ck_tile/ops/layernorm2d/pipeline/tile_layernorm2d_fwd_shape.hpp
View file @ 63214d01
...
@@ -12,13 +12,14 @@ template <typename ThreadTile, // Sequence<...
struct TileLayernorm2dShape
{
    static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
-   static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});
+   static constexpr index_t kNRepeat    = ThreadTile::at(number<1>{});
+   static constexpr index_t kNPerThread = ThreadTile::at(number<2>{});

    static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
    static constexpr index_t kNPerWarp = WarpTile::at(number<1>{});

    static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
-   static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;
+   static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread / kNRepeat;

    static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
    static constexpr index_t kNPerBlock = BlockTile::at(number<1>{});
...
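With the reworked shape, ThreadTile carries three entries, (kMPerThread, kNRepeat, kNPerThread), and the threads-per-warp derivation divides by both N factors. A small standalone sketch with made-up numbers (not an instance taken from this commit) showing how the counts fit together:

// Illustrative numbers for the reworked tile shape; values are made up.
constexpr int kMPerThread = 1;
constexpr int kNRepeat    = 2;
constexpr int kNPerThread = 4;

constexpr int kMPerWarp = 8;
constexpr int kNPerWarp = 64;

constexpr int kMThreadPerWarp = kMPerWarp / kMPerThread;            // 8
constexpr int kNThreadPerWarp = kNPerWarp / kNPerThread / kNRepeat; // 64 / 4 / 2 = 8

// One 64-lane wavefront covers the 8 x 64 warp tile with this split:
// each lane handles 1 row element and kNRepeat * kNPerThread = 8 columns.
static_assert(kMThreadPerWarp * kNThreadPerWarp == 64, "lane count must match");

int main() { return 0; }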