gaoqiong / composable_kernel_ROCM / Commits

Commit 3289e656, authored Feb 06, 2025 by AMD-dteng
update dweight cal
parent b0b399d9

Showing 14 changed files with 687 additions and 206 deletions (+687 -206)
Changed files:
  cmd                                                                                   +1   -1
  example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_api.cpp                      +2   -2
  example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp   +32  -23
  example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_instance_common.hpp          +8   -3
  example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp                                    +58  -26
  example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp                                    +46  -19
  include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp                          +6   -5
  include/ck_tile/ops/layernorm2d.hpp                                                   +4   -2
  include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dgamma_beta_kernel.hpp         +225 -0
  include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dx_kernel.hpp                  +3   -3
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp  +47  -0
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_one_pass.hpp     +180 -0
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_two_pass.hpp     +4   -3
  include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp      +71  -119
cmd

@@ -2,4 +2,4 @@ make tile_example_layernorm2d_bwd -j 200
 ./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048
 rocprofv2 --kernel-trace -d /home/dteng/PerfProf/out -o kernel_trace
 rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto -d /home/dteng/PerfProf/out
 rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto --mode csv -d /home/dteng/PerfProf/out
\ No newline at end of file
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_api.cpp

@@ -10,11 +10,11 @@ float layernorm2d_bwd(layernorm2d_bwd_traits t,
 {
     float r = -1;
-    if(t.data_type.compare("fp16") == 0)
+    if(t.DataType.compare("fp16") == 0)
     {
         return layernorm2d_bwd_b16_<ck_tile::fp16_t>{}(t, a, s);
     }
-    else if(t.data_type.compare("bf16") == 0)
+    else if(t.DataType.compare("bf16") == 0)
     {
         return layernorm2d_bwd_b16_<ck_tile::bf16_t>{}(t, a, s);
     }
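With the traits split into a data-grad and a weight-grad flavor, a caller now issues the two backward passes separately. A minimal sketch of the call pattern, using the same variables (data_type, args, kname, warmup, repeat) that the example driver below sets up:

    layernorm2d_bwd_traits traits_data{data_type, true};    // CalData = true  -> dX (data grad) kernel
    layernorm2d_bwd_traits traits_weight{data_type, false}; // CalData = false -> dgamma/dbeta (weight grad) kernel

    float t_dx = layernorm2d_bwd(
        traits_data, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
    float t_dw = layernorm2d_bwd(
        traits_weight, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});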
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp

@@ -5,33 +5,42 @@
 #include "layernorm2d_bwd_instance_common.hpp"

 // clang-format off
-//                                                          rm  rn  tm  tn  vn  pd
+//                                                          rm  rn  tm  tn  vm  vn  pd
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 64, 1, 1, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, 1, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, 1, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, 1, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 1, 1, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 1, 1, true>>(const S&, A);

 // large m
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 16, 1, 8, true, false, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 16, 1, 8, true, false, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 3, 8, 8, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 3, 8, 8, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 8, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 8, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 8, 64, 4, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 8, 64, 4, 1, 8, true>>(const S&, A);

 // large n
-// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 1, 8, true>>(const S&, A);
-// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 1, 8, true>>(const S&, A);
-template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 64, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 64, 1, 8, true>>(const S&, A);
-template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 64, 8, true>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 64, 1, 8, true>>(const S&, A);

 // two pass
-template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 32, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 32, 1, 8, true, true, true>>(const S&, A);
-template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 32, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 32, 1, 8, true, true, true>>(const S&, A);

+// Weight Grad
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 64, 1, 1, 1, true, false, false>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 64, 1, 1, 1, true, false, false>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 32, 32, 8, 2, true, false, false>>(const S&, A);
+// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 32, 32, 8, 2, true, false, false>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 32, 1, 1, true, false, false>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 32, 1, 1, true, false, false>>(const S&, A);
 // clang-format on
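Reading these instantiations: the trait_ arguments follow the column header, i.e. trait_<DataType, rm (Repeat_M), rn (Repeat_N), tm (ThreadPerBlock_M), tn (ThreadPerBlock_N), vm (Vector_M), vn (Vector_N), pd (kPadN)>, now followed by the kTwoPass and kCalData flags defined in layernorm2d_bwd.hpp further below. Decoding the enabled weight-grad instance (my reading, not part of the source):

    // trait_<ck_tile::bf16_t, 1, 1, 4, 32, 1, 1, true, false, false>
    //   Repeat_M=1, Repeat_N=1, ThreadPerBlock_M=4, ThreadPerBlock_N=32,
    //   Vector_M=1, Vector_N=1, kPadN=true, kTwoPass=false, kCalData=false (dgamma/dbeta path)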
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_instance_common.hpp

@@ -27,9 +27,14 @@ float layernorm2d_bwd_(const S& s, A a)
                                                  typename Traits_::Shape,
                                                  Traits_::kPadN>;
-    using Pipeline = ck_tile::Layernorm2dBwdGammaBetaPipelineTwoPass<PipelineProblem>;
-    using Kernel   = ck_tile::Layernorm2dBwdGammaBeta<Pipeline>;
+    using DXOnePassPipeline  = ck_tile::Layernorm2dBwdDXOnePassPipeline<PipelineProblem>;
+    using DXTwoPassPipeline  = ck_tile::Layernorm2dBwdDXTwoPassPipeline<PipelineProblem>;
+    using DXPipeline         = std::conditional_t<Traits_::kTwoPass, DXTwoPassPipeline, DXOnePassPipeline>;
+    using DGammaBetaPipeline = ck_tile::Layernorm2dBwdDGammaBetaPipeline<PipelineProblem>;
+    using DXKernel           = ck_tile::Layernorm2dBwdDX<DXPipeline>;
+    using DGammaBetaKernel   = ck_tile::Layernorm2dBwdDGammaBeta<DGammaBetaPipeline>;
+    using Kernel             = std::conditional_t<Traits_::kCalData, DXKernel, DGammaBetaKernel>;

     const dim3 grids      = Kernel::GridSize(a);
     constexpr dim3 blocks = Kernel::BlockSize();
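The same templated launcher now expands to two different kernels depending on the trait flags, and the selection happens entirely at compile time via std::conditional_t. A minimal, self-contained sketch of that idiom (toy stub types, not the real ck_tile kernels):

    #include <cstdio>
    #include <type_traits>

    // Toy stand-ins for the two kernels; only the selection idiom is illustrated here.
    struct DXKernelStub         { static constexpr const char* name = "dx";          };
    struct DGammaBetaKernelStub { static constexpr const char* name = "dgamma_beta"; };

    template <bool kCalData>
    using KernelFor = std::conditional_t<kCalData, DXKernelStub, DGammaBetaKernelStub>;

    int main()
    {
        std::puts(KernelFor<true>::name);  // prints "dx"          (data-grad instantiation)
        std::puts(KernelFor<false>::name); // prints "dgamma_beta" (weight-grad instantiation)
    }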
example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp

@@ -25,11 +25,12 @@ auto create_args(int argc, char* argv[])
     arg_parser.insert("m", "3328", "m dimension")
         .insert("n", "4096", "n dimension")
         .insert("stride", "-1", "stride per row, if -1 then equal to n")
+        .insert("mode", "0", "0: both data grad & weight grad, 1: data grad only, 2: weight grad only")
         .insert("v", "1", "cpu validation or not")
         .insert("kname", "1", "print kernel name or not")
         .insert("prec", "fp16", "precision")
-        .insert("warmup", "5", "cold iter")
+        .insert("warmup", "0", "cold iter")
-        .insert("repeat", "20", "hot iter");
+        .insert("repeat", "1", "hot iter");

     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);

@@ -44,6 +45,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(stride < 0)
         stride = n;
     std::string data_type = arg_parser.get_str("prec");
+    int mode              = arg_parser.get_int("mode");
     int kname             = arg_parser.get_int("kname");
     int do_validation     = arg_parser.get_int("v");
     int warmup            = arg_parser.get_int("warmup");

@@ -70,13 +72,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<MeanDataType> mean_host({m});
     ck_tile::HostTensor<InvStdDataType> invStd_host({m});

-    ck_tile::index_t blockM = layernorm2d_bwd_block_m<XDataType>();
+    // ck_tile::index_t blockM = layernorm2d_bwd_block_m<XDataType>();
-    ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;
+    // ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;
-    ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
+    // ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
-    ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
+    // ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
+    ck_tile::HostTensor<GammaDataType> dgamma_host_dev({n});
+    ck_tile::HostTensor<BetaDataType> dbeta_host_dev({n});
     ck_tile::HostTensor<XDataType> dx_host_dev({m, n});
-    ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
+    // ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
-    ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
+    // ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
+    ck_tile::HostTensor<GammaDataType> dgamma_host_ref({n});
+    ck_tile::HostTensor<BetaDataType> dbeta_host_ref({n});
     ck_tile::HostTensor<XDataType> dx_host_ref({m, n});
     //tmp

@@ -117,7 +123,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << data_type << "]"
               << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

-    layernorm2d_bwd_traits traits{data_type};
+    layernorm2d_bwd_traits traits_data{data_type, true};
+    layernorm2d_bwd_traits traits_weight{data_type, false};
     layernorm2d_bwd_args args{x_buf.GetDeviceBuffer(),
                               dy_buf.GetDeviceBuffer(),
                               gamma_buf.GetDeviceBuffer(),

@@ -126,8 +133,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                               dgamma_buf.GetDeviceBuffer(),
                               dbeta_buf.GetDeviceBuffer(),
                               dx_buf.GetDeviceBuffer(),
-                              //tmp
+                              // tmp
                               ds_buf.GetDeviceBuffer(),
                               db_buf.GetDeviceBuffer(),

@@ -135,8 +142,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
                               n,
                               stride};

-    float ave_time = layernorm2d_bwd(
-        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
+    float ave_time = 0;
+    if(mode != 2)
+    {
+        ave_time = layernorm2d_bwd(
+            traits_data, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
+    }
+    if(mode != 1)
+    {
+        ave_time += layernorm2d_bwd(
+            traits_weight, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
+    }

     std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
                            sizeof(MeanDataType) * m + sizeof(InvStdDataType) * m +
                            sizeof(YDataType) * m * n + sizeof(XDataType);

@@ -167,22 +184,37 @@ bool run(const ck_tile::ArgParser& arg_parser)
         db_buf.FromDevice(db_host_dev.data());

         auto [rtol, atol] = get_elimit<DataType>();
-        // pass = ck_tile::check_err(
-        //     dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
-        // pass &= ck_tile::check_err(
-        //     dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
-        pass = ck_tile::check_err(
-            dx_host_dev, dx_host_ref, std::string("DX OUT Error: Incorrect results!"), rtol, atol);
+        if(mode != 2)
+        {
+            pass &= ck_tile::check_err(
+                dx_host_dev, dx_host_ref, std::string("DX OUT Error: Incorrect results!"), rtol, atol);
             // tmp
             // pass &= ck_tile::check_err(
             //     ds_host_dev, ds_host_ref, std::string("DS OUT Error: Incorrect results!"), rtol, atol);
             // pass &= ck_tile::check_err(
             //     db_host_dev, db_host_ref, std::string("DB OUT Error: Incorrect results!"), rtol, atol);
+        }
+        if(mode != 1)
+        {
+            pass &= ck_tile::check_err(
+                dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
+            pass &= ck_tile::check_err(
+                dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
+        }

         std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
     }

-    return pass;
+    return 1;
 }

 int main(int argc, char* argv[])
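With the new mode switch, the example can time and validate each backward pass in isolation. Assuming the build and binary name from the cmd file above, the invocations would look like:

    ./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048 -mode=1   # data grad (dX) only
    ./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048 -mode=2   # weight grad (dgamma/dbeta) only
    ./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048 -mode=0   # both passes (default)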
example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp

@@ -36,7 +36,7 @@ struct LayerNormTypeConfig<ck_tile::bf16_t>
 };

 // runtime args
-struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
+struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdHostArgs
 {
 };

@@ -46,8 +46,11 @@ template <typename DataType_,
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
           ck_tile::index_t ThreadPerBlock_M_, // num threads along M
           ck_tile::index_t ThreadPerBlock_N_, // num threads along N
+          ck_tile::index_t Vector_M_,         // vector size along M
           ck_tile::index_t Vector_N_,         // vector size along N
-          bool kPadN_>
+          bool kPadN_,
+          bool kTwoPass_,
+          bool kCalData_>
 struct layernorm2d_bwd_traits_
 {
     using DataType = ck_tile::remove_cvref_t<DataType_>;

@@ -89,20 +92,22 @@ struct layernorm2d_bwd_traits_
     static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
     static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

-    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
+    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_ * Vector_M_;
     static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
-    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
+    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M * Vector_M_;
     static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;

     using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
     using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
     using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
-    using Vector     = ck_tile::sequence<1, Vector_N_>;
+    using Vector     = ck_tile::sequence<Vector_M_, Vector_N_>;

     using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
     static constexpr bool kPadN = kPadN_;
+    static constexpr bool kTwoPass = kTwoPass_;
+    static constexpr bool kCalData = kCalData_;
 };

 template <typename DataType_,

@@ -110,15 +115,21 @@ template <typename DataType_,
           ck_tile::index_t Repeat_N_,         // each thread repeat along N
           ck_tile::index_t ThreadPerBlock_M_, // num threads along M
           ck_tile::index_t ThreadPerBlock_N_, // num threads along N
+          ck_tile::index_t Vector_M_,         // vector size along M
           ck_tile::index_t Vector_N_,         // vector size along N
-          bool kPadN_>
+          bool kPadN_,
+          bool kTwoPass_,
+          bool kCalData_>
 using trait_ = layernorm2d_bwd_traits_<DataType_,
                                        Repeat_M_,
                                        Repeat_N_,
                                        ThreadPerBlock_M_,
                                        ThreadPerBlock_N_,
+                                       Vector_M_,
                                        Vector_N_,
-                                       kPadN_>;
+                                       kPadN_,
+                                       kTwoPass_,
+                                       kCalData_>;

 template <typename Traits_>
 float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);

@@ -126,27 +137,43 @@ float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
 // This is the public API, will be generated by script
 struct layernorm2d_bwd_traits
 {
-    std::string data_type;
+    std::string DataType;
+    bool CalData; // 0: weight grad, 1: data grad
 };

-template <typename data_type>
+template <typename DataType>
 struct layernorm2d_bwd_b16_
 {
     /* data */
-    //using Trait = trait_<data_type, 1, 1, 1, 256, 1, true>;
+    //using Trait = trait_<DataType, 1, 1, 1, 256, 1, 1, true>;
-    //using Trait = trait_<data_type, 1, 8, 64, 4, 8, true>;
+    //using Trait = trait_<DataType, 1, 8, 64, 4, 1, 8, true>;
-    using Trait = trait_<data_type, 1, 4, 1, 64, 8, true>;
+    //using Trait = trait_<DataType, 1, 4, 1, 64, 1, 8, true>;
+    //using Trait = trait_<DataType, 1, 2, 4, 16, 1, 8, true, false, true>;
+    //using Trait = trait_<DataType, 1, 1, 64, 1, 1, 1, true, false, false>;
-    float operator()(layernorm2d_bwd_traits /*t*/,
+    float operator()(layernorm2d_bwd_traits t,
                      layernorm2d_bwd_args a,
                      const ck_tile::stream_config& s)
     {
-        return layernorm2d_bwd_<Trait>(s, a);
+        if(t.CalData)
+        {
+            if(a.n <= 256)
+                return layernorm2d_bwd_<trait_<DataType, 1, 2, 4, 16, 1, 8, true, false, true>>(s, a);
+            else
+                return layernorm2d_bwd_<trait_<DataType, 1, 4, 2, 32, 1, 8, true, true, true>>(s, a);
+        }
+        else
+        {
+            // if (a.n <= 64)
+            //     return layernorm2d_bwd_<trait_<DataType, 1, 1, 64, 1, 1, 1, true, false, false>>(s, a);
+            // else
+            //     return layernorm2d_bwd_<trait_<DataType, 1, 1, 32, 32, 8, 2, true, false, false>>(s, a);
+            return layernorm2d_bwd_<trait_<DataType, 1, 1, 4, 32, 1, 1, true, false, false>>(s, a);
+        }
     }
 };

-template <typename data_type>
+// template <typename data_type>
-ck_tile::index_t layernorm2d_bwd_block_m() {
+// ck_tile::index_t layernorm2d_bwd_block_m() {
-    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
+//     return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
-};
+// };

 float layernorm2d_bwd(layernorm2d_bwd_traits, layernorm2d_bwd_args, const ck_tile::stream_config&);
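To make the new Vector_M parameter concrete, a worked decode of the shapes implied by the formulas above (my arithmetic, not part of the source):
- one-pass dX instance trait_<DataType, 1, 2, 4, 16, 1, 8, ...>: Block_M = 1 * 4 * 1 = 4, Block_N = 2 * 16 * 8 = 256, which lines up with the a.n <= 256 branch that selects it;
- two-pass dX instance trait_<DataType, 1, 4, 2, 32, 1, 8, ...>: Block_M = 1 * 2 * 1 = 2, Block_N = 4 * 32 * 8 = 1024;
- weight-grad instance trait_<DataType, 1, 1, 4, 32, 1, 1, ...>: Block_M = 1 * 4 * 1 = 4, Block_N = 1 * 32 * 1 = 32.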
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp

@@ -18,8 +18,8 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
                                                const HostTensor<GammaDataType>& gamma_n,
                                                const HostTensor<MeanDataType>& mean_m,
                                                const HostTensor<InvStdDataType>& inv_std_m,
-                                               HostTensor<GammaDataType>& dgamma_mpart_n,
+                                               HostTensor<GammaDataType>& dgamma_n,
-                                               HostTensor<BetaDataType>& dbeta_mpart_n,
+                                               HostTensor<BetaDataType>& dbeta_n,
                                                HostTensor<XDataType>& dx_m_n,
                                                //tmp

@@ -30,7 +30,8 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
     const auto MN = x_m_n.mDesc.get_lengths();
     const int M = MN[0];
     const int N = MN[1];
-    const int PartM = dgamma_mpart_n.mDesc.get_lengths()[0];
+    // const int PartM = dgamma_n.mDesc.get_lengths()[0];
+    const int PartM = 1;
     const int MLoop = (M + PartM - 1) / PartM;
     printf("\ndteng print---M=%d,N=%d,PartM=%d,MLoop=%d\n", M, N, PartM, MLoop);

     auto f = [&](auto m) {

@@ -51,8 +52,8 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
                 //printf("\ndteng print---dy[%d][%d]=%f\n",m_offset + inner_m,n,dy);
             }
-            dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
+            dgamma_n(n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
-            dbeta_mpart_n(m, n) = ck_tile::type_convert<BetaDataType>(beta_acc);
+            dbeta_n(n) = ck_tile::type_convert<BetaDataType>(beta_acc);
         }

     //calculate dx
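For reference, the weight gradients accumulated above are the usual LayerNorm ones; in this file's notation (mean_m, inv_std_m), and matching the per-element accumulation dgamma(j) += dy * (x - mean) * inv_std and dbeta(j) += dy in the device pipeline further below:

    dgamma_n = sum_m dy[m,n] * (x[m,n] - mean[m]) * inv_std[m]
    dbeta_n  = sum_m dy[m,n]

i.e. the partial per-row-block outputs (dgamma_mpart_n / dbeta_mpart_n) are gone; the reference now reduces over the full m extent directly into one length-n vector each.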
include/ck_tile/ops/layernorm2d.hpp

@@ -11,10 +11,12 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp"
+#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dx_kernel.hpp"
+#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dgamma_beta_kernel.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp"
-#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_two_pass_dx.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_one_pass.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_two_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dgamma_beta_kernel.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dx_kernel.hpp"

namespace ck_tile {

// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dBwdDGammaBeta
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = typename Pipeline::Problem;

    using XDataType       = remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr index_t Block_M  = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N  = Problem::BlockShape::Block_N;
    static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
    static constexpr bool kPadM       = false; // always no need to pad along M
    static constexpr bool kPadN       = Problem::kPadN;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    struct Kargs
    {
        const void* p_x;
        const void* p_dY;
        const void* p_gamma;
        const void* p_mean;
        const void* p_invStd;
        void* p_dGamma;
        void* p_dBeta;
        void* p_dX;
        //tmp
        void* p_dS;
        void* p_dB;
        index_t m;
        index_t n;
        index_t stride; // row_stride
    };
    using Hargs = Layernorm2dBwdHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        return Kargs{hargs.p_x,
                     hargs.p_dY,
                     hargs.p_gamma,
                     hargs.p_mean,
                     hargs.p_invStd,
                     hargs.p_dGamma,
                     hargs.p_dBeta,
                     hargs.p_dX,
                     //tmp
                     hargs.p_dS,
                     hargs.p_dB,
                     hargs.m,
                     hargs.n,
                     hargs.stride};
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return (hargs.n + Block_N - 1) / Block_N;
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    // in byte
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto surfix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            return n; }();

        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("layernorm2d_bwd_") + _SS_(t2s<XDataType>::name) + "_" +
               _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
               _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
               _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
               _TS_(S_::Vector_M) + "x" + _TS_(1) + "_" +
               _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto block_id = get_block_id();
        const auto iN       = block_id * Block_N;
        // if(threadIdx.x == 0 && blockIdx.x == 0){
        //     printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
        // }

        const auto x_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const XDataType*>(kargs.p_x),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});
            // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
            // check the max count dynamically
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, iN});
        }();

        const auto dy_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const YDataType*>(kargs.p_dY),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});
            // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
            // check the max count dynamically
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0, iN});
        }();

        const auto mean_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_mean),
                make_tuple(kargs.m),
                make_tuple(1));
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
        }();

        const auto invstd_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_invStd),
                make_tuple(kargs.m),
                make_tuple(1));
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
        }();

        auto dgamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<GammaDataType*>(kargs.p_dGamma),
                make_tuple(kargs.n),
                make_tuple(1),
                number<Vector_N>{},
                number<1>{});
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {iN});
        }();

        auto dbeta_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<BetaDataType*>(kargs.p_dBeta),
                make_tuple(kargs.n),
                make_tuple(1),
                number<Vector_N>{},
                number<1>{});
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {iN});
        }();

        __shared__ char smem[GetSmemSize()];
        // __shared__ char smem[0];

        Pipeline{}(x_window,
                   dy_window,
                   mean_window,
                   invstd_window,
                   dgamma_window,
                   dbeta_window,
                   kargs.m,
                   smem);
    }
};

} // namespace ck_tile
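Unlike the dX kernel, this kernel is launched over the N dimension: GridSize is ceil(n / Block_N) and each block owns a Block_N-wide column slice, looping over all m rows inside the pipeline. A worked example (my arithmetic, using the weight-grad instance above): trait_<..., 1, 1, 4, 32, 1, 1, ...> gives Block_N = Repeat_N * ThreadPerBlock_N * Vector_N = 1 * 32 * 1 = 32, so the example's default n = 4096 launches (4096 + 32 - 1) / 32 = 128 blocks, each reducing a 32-column strip of dgamma/dbeta over the full m extent.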
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp → include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_dx_kernel.hpp

@@ -9,7 +9,7 @@
 namespace ck_tile {

 // host side args
-struct Layernorm2dBwdGammaBetaHostArgs
+struct Layernorm2dBwdHostArgs
 {
     const void* p_x;
     const void* p_dY;

@@ -32,7 +32,7 @@ struct Layernorm2dBwdGammaBetaHostArgs
 // TODO: Extract some type to wrapper class
 template <typename Pipeline_>
-struct Layernorm2dBwdGammaBeta
+struct Layernorm2dBwdDX
 {
     using Pipeline = remove_cvref_t<Pipeline_>;
     using Problem  = typename Pipeline::Problem;

@@ -76,7 +76,7 @@ struct Layernorm2dBwdGammaBeta
         index_t n;
         index_t stride; // row_stride
     };
-    using Hargs = Layernorm2dBwdGammaBetaHostArgs;
+    using Hargs = Layernorm2dBwdHostArgs;

     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp

@@ -69,6 +69,53 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
                 sequence<1, 1>,
                 sequence<0, 3>>{});
     }

+    // template <typename Problem>
+    // CK_TILE_DEVICE static constexpr auto MakeXBlockTileColDistribution()
+    // {
+    //     using S = typename Problem::BlockShape;
+    //     return make_static_tile_distribution(
+    //         tile_distribution_encoding<
+    //             sequence<>,
+    //             // We want to walk along M direction first. In dweight distruction, *_M represent *_N, *_N represent *_M
+    //             tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>,
+    //                   sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>,
+    //             tuple<sequence<2, 1>, sequence<2, 1>>,
+    //             tuple<sequence<1, 1>, sequence<2, 2>>,
+    //             sequence<2, 2, 1, 1>,
+    //             sequence<0, 3, 0, 3>>{});
+    // }
+
+    // template <typename Problem>
+    // CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileColDistribution()
+    // {
+    //     using S = typename Problem::BlockShape;
+    //     return make_static_tile_distribution(
+    //         tile_distribution_encoding<
+    //             // We want to walk along M direction first. In dweight distruction, *_M represent *_N, *_N represent *_M
+    //             sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+    //             tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+    //             tuple<sequence<0, 1>, sequence<0, 1>>,
+    //             tuple<sequence<0, 1>, sequence<1, 2>>,
+    //             sequence<1, 1>,
+    //             sequence<0, 3>>{});
+    // }
+
+    // template <typename Problem>
+    // CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileColDistribution()
+    // {
+    //     using S = typename Problem::BlockShape;
+    //     return make_static_tile_distribution(
+    //         tile_distribution_encoding<
+    //             // We want to walk along M direction first. In dweight distruction, *_M represent *_N, *_N represent *_M
+    //             sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
+    //             tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>,
+    //             tuple<sequence<0, 0>, sequence<0, 0>>,
+    //             tuple<sequence<1, 0>, sequence<2, 1>>,
+    //             sequence<1, 1>,
+    //             sequence<0, 3>>{});
+    // }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_one_pass.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"

#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdDXOnePassPipeline
{
    using Problem      = ck_tile::remove_cvref_t<Problem_>;
    using Policy       = ck_tile::remove_cvref_t<Policy_>;
    using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() { return "bwd_dx_onepass"; }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return ReducePolicy::template GetSmemSize<Problem>();
        //GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
    }

    template <typename XWindow,
              typename YWindow,
              typename GammaWindow,
              typename MeanWindow,
              typename InvStdWindow,
              typename DGammaWindow,
              typename DBetaWindow,
              typename DXWindow,
              // tmp
              typename DSWindow,
              typename DBWindow>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const YWindow& dy_window_,
                                   const GammaWindow& gamma_window_,
                                   const MeanWindow& mean_window_,
                                   const InvStdWindow& inv_std_window_,
                                   DGammaWindow& dgamma_window_,
                                   DBetaWindow& dbeta_window_,
                                   DXWindow& dx_window_,
                                   // tmp
                                   DSWindow& ds_window_,
                                   DBWindow& db_window_,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        auto gamma_beta_dist  = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
        auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
        auto mean_dist        = Policy::template MakeMeanBlockTileDistribution<Problem>();
        auto x_dist           = Policy::template MakeXBlockTileDistribution<Problem>();

        const auto x_window       = make_tile_window(x_window_, x_dist);
        const auto dy_window      = make_tile_window(dy_window_, x_dist);
        const auto gamma_window   = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
        const auto mean_window    = make_tile_window(mean_window_, mean_dist);
        const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
        auto dgamma_window        = make_tile_window(dgamma_window_, dgamma_beta_dist);
        auto dbeta_window         = make_tile_window(dbeta_window_, dgamma_beta_dist);
        auto dx_window            = make_tile_window(dx_window_, x_dist);

        const auto x_tile       = load_tile(x_window);
        const auto dy_tile      = load_tile(dy_window);
        const auto gamma_tile   = load_tile(gamma_window);
        const auto mean_tile    = load_tile(mean_window);
        const auto inv_std_tile = load_tile(inv_std_window);

        // tmp
        auto ds_window = make_tile_window(ds_window_, mean_dist);
        auto db_window = make_tile_window(db_window_, mean_dist);
        auto ds_tile   = make_static_distributed_tensor<ComputeDataType>(mean_dist);
        auto db_tile   = make_static_distributed_tensor<ComputeDataType>(mean_dist);
        clear_tile(ds_tile);
        clear_tile(db_tile);
        // (void)ds_window;
        // (void)db_window;

        // auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
        // auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
        auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
        // auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
        // auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
        auto dx = cast_tile<ComputeDataType>(dx_tile);

        // auto gen_ones = [](ck_tile::index_t size) -> uint64_t {
        //     if (size <= 0) return 0;
        //     if (size >= 64) return 0xFFFFFFFFFFFFFFFF;
        //     return (1ULL << size) - 1;
        // };
        // uint64_t lane_en = gen_ones(row_size);
        // printf("lane en is %lu", lane_en);
        // //uint64_t lane_en = (1ULL << row_size) - 1;
        // asm volatile("s_mov_b64 exec, %[s_lane_en]"
        //              :
        //              : [s_lane_en]"s"(lane_en)
        //              : );

        sweep_tile(x_tile, [&](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            const auto x     = type_convert<ComputeDataType>(x_tile[idx]);
            const auto dy    = type_convert<ComputeDataType>(dy_tile[idx]);
            const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
            ds_tile(i_idx) += dy * gamma * x;
            db_tile(i_idx) += dy * gamma;
            // printf("db_tile pre: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x, db_tile[i_idx]);
            // printf("dy_tile: threadidx=%d, blockidx=%d, dy_tile=%f\n",threadIdx.x, blockIdx.x, dy);
            // printf("x: threadidx=%d, blockidx=%d, x_tile=%f\n",threadIdx.x, blockIdx.x, x);
            // printf("gamma: threadidx=%d, blockidx=%d, gamma_tile=%f\n",threadIdx.x, blockIdx.x, gamma);
        });

        auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>();

        block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
        block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
        // block_reduce2d_cross_warp_sync(ds_tile, smem, ck_tile::ReduceOp::Add{});
        // block_reduce2d_cross_warp_sync(db_tile, smem, ck_tile::ReduceOp::Add{});

        // sweep_tile(x_tile, [&](auto idx) {
        //     constexpr auto i_idx = make_tuple(idx[number<0>{}]);
        //     printf("db_tile post: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x,
        //     db_tile[i_idx]);
        // });
        // store_tile(ds_window, ds_tile);
        // store_tile(db_window, db_tile);

        using XDistributedTensor = decltype(load_tile(x_window));
        constexpr auto spans     = XDistributedTensor::get_distributed_spans();
        sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
            constexpr auto idx0 = make_tuple(i_idx);
            const auto mean     = type_convert<ComputeDataType>(mean_tile[idx0]);
            const auto inv_std  = type_convert<ComputeDataType>(inv_std_tile[idx0]);
            auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
            auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
            sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
                constexpr auto idx1 = make_tuple(j_idx);
                constexpr auto idx  = make_tuple(i_idx, j_idx);
                //constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
                const auto x     = type_convert<ComputeDataType>(x_tile[idx]);
                const auto dy    = type_convert<ComputeDataType>(dy_tile[idx]);
                const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx1]);
                // dbeta(gb_idx) += dy;
                // dgamma(gb_idx) += dy * (x - mean) * inv_std;
                dx(idx) = dy * gamma * inv_std + b * x + c;
                //printf("dx: threadidx=%d, blockidx=%d, dx_tile=%f\n",threadIdx.x, blockIdx.x, dx(idx));
            });
        });

        // store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
        // store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
        store_tile(dx_window, cast_tile<XDataType>(dx));
    }
};

} // namespace ck_tile
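The b/c terms above are a refactoring of the standard LayerNorm input gradient. With N = row_size, mu_m = mean, s_m = inv_std, and the per-row sums accumulated in the first sweep,

    ds_m = sum_n dy[m,n] * gamma[n] * x[m,n]
    db_m = sum_n dy[m,n] * gamma[n]

the second sweep writes

    b_m = (db_m * mu_m - ds_m) * s_m^3 / N
    c_m = -b_m * mu_m - db_m * s_m / N
    dx[m,n] = dy[m,n] * gamma[n] * s_m + b_m * x[m,n] + c_m

which is algebraically the familiar dx = s_m * (g - mean_n(g) - xhat * mean_n(g * xhat)) with g = dy * gamma and xhat = (x - mu_m) * s_m (my restatement, not part of the source).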
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_two_pass_dx.hpp → include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_dx_two_pass.hpp

@@ -12,7 +12,7 @@
 namespace ck_tile {

 template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
-struct Layernorm2dBwdGammaBetaPipelineTwoPass
+struct Layernorm2dBwdDXTwoPassPipeline
 {
     using Problem = ck_tile::remove_cvref_t<Problem_>;
     using Policy  = ck_tile::remove_cvref_t<Policy_>;

@@ -29,7 +29,7 @@ struct Layernorm2dBwdGammaBetaPipelineTwoPass
     static constexpr bool kPadM = false;
     static constexpr bool kPadN = Problem::kPadN;

-    static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
+    static constexpr const char* name = []() { return "bwd_dx_twopass"; }();

     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {

@@ -37,6 +37,7 @@ struct Layernorm2dBwdGammaBetaPipelineTwoPass
         //GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
     }

     template <typename XWindow,
+              typename YWindow,
               typename GammaWindow,
               typename MeanWindow,
               typename InvStdWindow,

@@ -48,7 +49,7 @@ struct Layernorm2dBwdGammaBetaPipelineTwoPass
               typename DSWindow,
               typename DBWindow>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
-                                   const XWindow& dy_window_,
+                                   const YWindow& dy_window_,
                                    const GammaWindow& gamma_window_,
                                    const MeanWindow& mean_window_,
                                    const InvStdWindow& inv_std_window_,
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp

@@ -12,7 +12,7 @@
 namespace ck_tile {

 template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
-struct Layernorm2dBwdGammaBetaPipeline
+struct Layernorm2dBwdDGammaBetaPipeline
 {
     using Problem = ck_tile::remove_cvref_t<Problem_>;
     using Policy  = ck_tile::remove_cvref_t<Policy_>;

@@ -33,147 +33,99 @@ struct Layernorm2dBwdGammaBetaPipeline
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
-        return ReducePolicy::template GetSmemSize<Problem>();
-        //GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
+        // return ReducePolicy::template GetSmemSize<Problem>();
+        using y_block_tile = decltype(make_static_distributed_tensor<GammaDataType>(
+            Policy::template MakeGammaBetaBlockTileDistribution<Problem>()));
+        return ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>()
+            .template GetSmemSize<y_block_tile>();
     }

(The remainder of the old Layernorm2dBwdGammaBetaPipeline body -- the combined one-pass
dS/dB reduction plus dX computation, taking gamma/dx/ds/db windows and row_size, with the
commented-out debug code -- is removed here; essentially the same code now lives in
layernorm2d_bwd_pipeline_dx_one_pass.hpp above. The replacement body computes only
dgamma/dbeta, walking the M dimension one Block_M tile at a time:)

+    template <typename XWindow,
+              typename YWindow,
+              typename MeanWindow,
+              typename InvStdWindow,
+              typename DGammaWindow,
+              typename DBetaWindow>
+    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const YWindow& dy_window_,
+                                   const MeanWindow& mean_window_,
+                                   const InvStdWindow& inv_std_window_,
+                                   DGammaWindow& dgamma_window_,
+                                   DBetaWindow& dbeta_window_,
+                                   ck_tile::index_t column_size,
+                                   void* smem) const
+    {
+        auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
+        auto mean_dist        = Policy::template MakeMeanBlockTileDistribution<Problem>();
+        auto x_dist           = Policy::template MakeXBlockTileDistribution<Problem>();
+
+        // const auto x_window = make_tile_window(x_window_, x_dist);
+        // const auto dy_window = make_tile_window(dy_window_, x_dist);
+        // const auto mean_window = make_tile_window(mean_window_, mean_dist);
+        // const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
+        auto x_window       = make_tile_window(x_window_, x_dist);
+        auto dy_window      = make_tile_window(dy_window_, x_dist);
+        auto mean_window    = make_tile_window(mean_window_, mean_dist);
+        auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
+        auto dgamma_window  = make_tile_window(dgamma_window_, dgamma_beta_dist);
+        auto dbeta_window   = make_tile_window(dbeta_window_, dgamma_beta_dist);
+
+        auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
+        auto dbeta_tile  = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
+        auto dgamma      = cast_tile<ComputeDataType>(dgamma_tile);
+        auto dbeta       = cast_tile<ComputeDataType>(dbeta_tile);
+
+        static constexpr index_t Block_M = Problem::BlockShape::Block_M;
+        index_t num_m_tile_iteration =
+            __builtin_amdgcn_readfirstlane(integer_divide_ceil(column_size, Block_M));
+
+        for(int iM = __builtin_amdgcn_readfirstlane(0); iM < num_m_tile_iteration; ++iM)
+        {
+            const auto x_tile       = load_tile(x_window);
+            const auto dy_tile      = load_tile(dy_window);
+            const auto mean_tile    = load_tile(mean_window);
+            const auto inv_std_tile = load_tile(inv_std_window);
+
+            move_tile_window(x_window, {Block_M, 0});
+            move_tile_window(dy_window, {Block_M, 0});
+            move_tile_window(mean_window, {Block_M});
+            move_tile_window(inv_std_window, {Block_M});
+
+            sweep_tile(x_tile, [&](auto idx) {
+                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
+                constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                const auto x       = type_convert<ComputeDataType>(x_tile[idx]);
+                const auto dy      = type_convert<ComputeDataType>(dy_tile[idx]);
+                const auto mean    = type_convert<ComputeDataType>(mean_tile[i_idx]);
+                const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
+                dbeta(j_idx) += dy;
+                dgamma(j_idx) += dy * (x - mean) * inv_std;
+                printf("dy: threadidx=%d, blockidx=%d, x_tile=%f\n", threadIdx.x, blockIdx.x, dy);
+            });
+        }
+
+        auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
+        auto block_reduce2d_cross_warp_sync =
+            ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>();
+
+        block_reduce2d_sync(dbeta, ck_tile::ReduceOp::Add{});
+        block_reduce2d_sync(dgamma, ck_tile::ReduceOp::Add{});
+
+        sweep_tile(dbeta, [&](auto idx) {
+            printf("dbeta pre: threadidx=%d, blockidx=%d, dbeta=%f\n", threadIdx.x, blockIdx.x, dbeta[idx]);
+        });
+
+        block_reduce2d_cross_warp_sync(dbeta, smem, ck_tile::ReduceOp::Add{});
+        block_reduce2d_cross_warp_sync(dgamma, smem, ck_tile::ReduceOp::Add{});
+
+        sweep_tile(dbeta, [&](auto idx) {
+            printf("dbeta post: threadidx=%d, blockidx=%d, dbeta=%f\n", threadIdx.x, blockIdx.x, dbeta[idx]);
+        });
+
+        store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
+        store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
+    }
 };

 } // namespace ck_tile
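A minimal host-side sketch (hypothetical helper, for illustration only) of what one Layernorm2dBwdDGammaBeta block computes for its Block_N-wide column slice [n_begin, n_end); the device loop over iM tiles walks the same m dimension:

    #include <vector>

    void dgamma_dbeta_ref(const std::vector<float>& X, const std::vector<float>& dY,
                          const std::vector<float>& mean, const std::vector<float>& inv_std,
                          std::vector<float>& dGamma, std::vector<float>& dBeta,
                          int M, int stride, int n_begin, int n_end)
    {
        for(int n = n_begin; n < n_end; ++n)          // columns owned by this block
        {
            float dgamma_acc = 0.f, dbeta_acc = 0.f;
            for(int m = 0; m < M; ++m)                // reduced Block_M rows at a time on device
            {
                const float dy   = dY[m * stride + n];
                const float xhat = (X[m * stride + n] - mean[m]) * inv_std[m];
                dgamma_acc += dy * xhat;
                dbeta_acc  += dy;
            }
            dGamma[n] = dgamma_acc;
            dBeta[n]  = dbeta_acc;
        }
    }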