Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
abe875d6
Commit
abe875d6
authored
Oct 16, 2024
by
rocking
Browse files
unify layernorm api
parent
02236580
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
190 additions
and
270 deletions
+190
-270
example/ck_tile/02_layernorm2d/CMakeLists.txt
example/ck_tile/02_layernorm2d/CMakeLists.txt
+1
-1
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
+2
-15
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
+187
-0
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
+0
-151
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
+0
-103
No files found.
example/ck_tile/02_layernorm2d/CMakeLists.txt
View file @
abe875d6
...
...
@@ -5,7 +5,7 @@ message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
EXAMPLE_LAYERNORM2D_FWD
}
EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp
)
target_include_directories
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE layernorm2d_fwd_
fp16.cpp layernorm2d_fwd_fp32
.cpp
${
INSTANCE_SRCS
}
)
target_sources
(
${
EXAMPLE_LAYERNORM2D_FWD
}
PRIVATE layernorm2d_fwd_
api
.cpp
${
INSTANCE_SRCS
}
)
set
(
EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS
)
...
...
example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
View file @
abe875d6
...
...
@@ -2,9 +2,6 @@
#include "layernorm2d_fwd.hpp"
#include <cstring>
extern
float
layernorm2d_fwd_fp16
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
float
layernorm2d_fwd_fp32
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
auto
create_args
(
int
argc
,
char
*
argv
[])
{
ck_tile
::
ArgParser
arg_parser
;
...
...
@@ -95,18 +92,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
M
,
N
};
float
ave_time
=
.0
;
if
constexpr
(
std
::
is_same
<
DataType
,
ck_tile
::
fp16_t
>::
value
)
{
ave_time
=
layernorm2d_fwd_fp16
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
});
}
else
if
constexpr
(
std
::
is_same
<
DataType
,
float
>::
value
)
{
ave_time
=
layernorm2d_fwd_fp32
(
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
});
}
float
ave_time
=
layernorm2d_fwd
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
0
,
warmup
,
repeat
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
N
+
sizeof
(
GammaDataType
)
*
N
+
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
YDataType
)
*
M
*
N
;
...
...
example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
0 → 100644
View file @
abe875d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
float
layernorm2d_fwd
(
layernorm2d_fwd_traits
t
,
layernorm2d_fwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
float
r
=
-
1
;
if
(
t
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding
// compiler
#if 0
if(a.N % 8 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(a, s);
}
else if(a.N <= 256)
{
return a.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(a, s);
}
else if(a.N <= 512)
{
return a.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(a, s);
}
else if(a.N <= 1024)
{
return a.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(a, s);
}
else
{
return a.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(a, s)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(a, s);
}
}
else if(a.N % 4 == 0)
#endif
if
(
a
.
N
%
4
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>
(
a
,
s
);
}
else
{
return
a
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>
(
a
,
s
);
}
}
else
if
(
a
.
N
%
2
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>
(
a
,
s
);
}
else
{
return
a
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>
(
a
,
s
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>
(
a
,
s
);
}
}
}
else
if
(
t
.
data_type
.
compare
(
"fp32"
)
==
0
)
{
if
(
a
.
N
%
4
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
run_layernorm
<
float
,
1
,
32
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
1
,
32
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
run_layernorm
<
float
,
1
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
1
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
run_layernorm
<
float
,
2
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
2
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
run_layernorm
<
float
,
4
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
4
,
64
,
4
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
>
(
a
,
s
);
}
else
{
return
a
.
N
%
2048
==
0
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
,
true
>
(
a
,
s
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
,
true
>
(
a
,
s
);
}
}
else
if
(
a
.
N
%
2
==
0
)
{
if
(
a
.
N
<=
128
)
{
return
a
.
N
==
128
?
run_layernorm
<
float
,
1
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
1
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
256
)
{
return
a
.
N
==
256
?
run_layernorm
<
float
,
2
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
2
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
512
)
{
return
a
.
N
==
512
?
run_layernorm
<
float
,
4
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
4
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
1024
)
{
return
a
.
N
==
1024
?
run_layernorm
<
float
,
8
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
8
,
64
,
2
,
true
>
(
a
,
s
);
}
else
if
(
a
.
N
<=
2048
)
{
return
a
.
N
==
2048
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
>
(
a
,
s
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
>
(
a
,
s
);
}
else
{
return
a
.
N
%
2048
==
0
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
,
true
>
(
a
,
s
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
,
true
>
(
a
,
s
);
}
}
}
return
r
;
}
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp16.cpp
deleted
100644 → 0
View file @
02236580
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
float
layernorm2d_fwd_fp16
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
#if 0
if(param.N % 8 == 0)
{
if(param.N <= 128)
{
return param.N == 128 ? run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(param, stream);
}
else if(param.N <= 256)
{
return param.N == 256 ? run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(param, stream);
}
else if(param.N <= 512)
{
return param.N == 512 ? run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(param, stream);
}
else if(param.N <= 1024)
{
return param.N == 1024 ? run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(param, stream);
}
else
{
return param.N == 2048 ? run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(param, stream)
: run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(param, stream);
}
}
else if(param.N % 4 == 0)
#endif
if
(
param
.
N
%
4
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
32
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
4
,
true
,
true
>
(
param
,
stream
);
}
}
else
if
(
param
.
N
%
2
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
1
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
2
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
4
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
8
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
ck_tile
::
fp16_t
,
16
,
64
,
2
,
true
,
true
>
(
param
,
stream
);
}
}
else
{
throw
std
::
runtime_error
(
"Sequence length sizes not supported!"
);
}
};
example/ck_tile/02_layernorm2d/layernorm2d_fwd_fp32.cpp
deleted
100644 → 0
View file @
02236580
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"
// clang-format off
extern
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
false
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
32
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
1
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
2
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
4
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
8
,
64
,
4
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
extern
template
float
run_layernorm
<
float
,
16
,
64
,
2
,
true
>(
const
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
);
// clang-format on
float
layernorm2d_fwd_fp32
(
layernorm2d_fwd_args
&
param
,
ck_tile
::
stream_config
stream
)
{
if
(
param
.
N
%
4
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
float
,
1
,
32
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
32
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
float
,
1
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
float
,
2
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
2
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
float
,
4
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
4
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
float
,
8
,
64
,
4
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
4
,
true
,
true
>
(
param
,
stream
);
}
}
else
if
(
param
.
N
%
2
==
0
)
{
if
(
param
.
N
<=
128
)
{
return
param
.
N
==
128
?
run_layernorm
<
float
,
1
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
1
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
256
)
{
return
param
.
N
==
256
?
run_layernorm
<
float
,
2
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
2
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
512
)
{
return
param
.
N
==
512
?
run_layernorm
<
float
,
4
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
4
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
1024
)
{
return
param
.
N
==
1024
?
run_layernorm
<
float
,
8
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
8
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
if
(
param
.
N
<=
2048
)
{
return
param
.
N
==
2048
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
>
(
param
,
stream
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
>
(
param
,
stream
);
}
else
{
return
param
.
N
%
2048
==
0
?
run_layernorm
<
float
,
16
,
64
,
2
,
false
,
true
>
(
param
,
stream
)
:
run_layernorm
<
float
,
16
,
64
,
2
,
true
,
true
>
(
param
,
stream
);
}
}
else
{
throw
std
::
runtime_error
(
"Sequence length sizes not supported!"
);
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment