gaoqiong / composable_kernel_ROCM · Commits

Commit 677a842e, authored Jan 14, 2025 by AMD-dteng
Parent: 5d671a5f

    local base version

Showing 13 changed files with 1012 additions and 0 deletions (+1012 -0)
Changed files:

    example/ck_tile/02_layernorm2d/CMakeLists.txt                                          +22  -0
    example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_api.cpp                       +25  -0
    example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp    +11  -0
    example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_instance_common.hpp           +44  -0
    example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp                                     +187 -0
    example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp                                     +142 -0
    include/ck_tile/host.hpp                                                               +1   -0
    include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp                           +86  -0
    include/ck_tile/ops/layernorm2d.hpp                                                    +6   -0
    include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp           +244 -0
    include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp   +79  -0
    include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp       +132 -0
    include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp          +33  -0
example/ck_tile/02_layernorm2d/CMakeLists.txt

```diff
@@ -42,3 +42,25 @@ target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_
 # however, this property may affect global
 # TODO: consider codegen a makefile by us
 set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
+
+set(EXAMPLE_LAYERNORM2D_BWD "tile_example_layernorm2d_bwd")
+# not using add_example_executable() to add this target, since we don't want this to have
+# to be included in "make all/install/check"
+message("adding example ${EXAMPLE_LAYERNORM2D_BWD}")
+file(GLOB INSTANCE_SRCS instances/*.cpp)
+add_executable(${EXAMPLE_LAYERNORM2D_BWD} EXCLUDE_FROM_ALL layernorm2d_bwd.cpp)
+target_include_directories(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_sources(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${INSTANCE_SRCS})
+
+set(EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS)
+# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
+list(APPEND EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
+target_compile_options(${EXAMPLE_LAYERNORM2D_BWD} PRIVATE ${EXAMPLE_layernorm2d_bwd_COMPILE_OPTIONS})
+
+# TODO: we have to turn off this global prop, otherwise the progress bar generated
+# by cmake will print too many files, execvp: /bin/sh: Argument list too long
+# however, this property may affect global
+# TODO: consider codegen a makefile by us
+set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
```
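Note that the target is added with EXCLUDE_FROM_ALL, so a plain `make` will not build it; `tile_example_layernorm2d_bwd` has to be requested as an explicit build target. The `-Wno-undefined-func-template` flag is needed because `layernorm2d_bwd_` is only declared in the headers, while its explicit instantiations live in the `instances/*.cpp` sources picked up by the GLOB.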
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_api.cpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_bwd.hpp"

float layernorm2d_bwd(layernorm2d_bwd_traits t,
                      layernorm2d_bwd_args a,
                      const ck_tile::stream_config& s)
{
    float r = -1;
    if(t.data_type.compare("fp16") == 0)
    {
        return layernorm2d_bwd_b16_<ck_tile::fp16_t>{}(t, a, s);
    }
    else if(t.data_type.compare("bf16") == 0)
    {
        return layernorm2d_bwd_b16_<ck_tile::bf16_t>{}(t, a, s);
    }
    if(r < 0)
        throw std::runtime_error("Without supported instances!");
    return r;
}
```
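A minimal call-site sketch of this dispatch entry point, assuming device buffers have already been allocated (the `*_dev` pointer names are hypothetical; the full setup lives in layernorm2d_bwd.cpp below):

```cpp
// Hypothetical device pointers; see layernorm2d_bwd.cpp for the real buffer setup.
layernorm2d_bwd_traits t{"fp16"}; // "bf16" is the other supported string
layernorm2d_bwd_args a{};         // inherits ck_tile::Layernorm2dBwdGammaBetaHostArgs
a.p_x      = p_x_dev;
a.p_dY     = p_dy_dev;
a.p_gamma  = p_gamma_dev;
a.p_mean   = p_mean_dev;
a.p_invStd = p_invstd_dev;
a.p_dGamma = p_dgamma_dev;
a.p_dBeta  = p_dbeta_dev;
a.p_dX     = p_dx_dev;
a.m = 3328; a.n = 4096; a.stride = 4096;
// stream_config{stream, time_kernel, log_level, cold_niters, nrepeat}, as constructed in the example
float avg_ms = layernorm2d_bwd(t, a, ck_tile::stream_config{nullptr, true, 1, 5, 20});
```

Any other `data_type` string falls through to the `std::runtime_error`.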
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "layernorm2d_bwd_instance_common.hpp"

// clang-format off
//                                                      rm  tm  tn  pd
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1,  1,  64, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1,  1,  64, true>>(const S&, A);
// clang-format on
```
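Despite the bf16-specific file name, this translation unit explicitly instantiates `layernorm2d_bwd_` for both `ck_tile::bf16_t` and `ck_tile::fp16_t` with the same 1x1x64 padded-N trait, which is what lets both branches of the `layernorm2d_bwd` dispatcher in layernorm2d_bwd_api.cpp link.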
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_instance_common.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_bwd.hpp"
#include <iostream>

#pragma once

using S = ck_tile::stream_config;
using A = layernorm2d_bwd_args;

template <typename Traits_>
float layernorm2d_bwd_(const S& s, A a)
{
    using DataType = typename Traits_::DataType;

    using PipelineProblem = ck_tile::Layernorm2dBwdGammaBetaPipelineProblem<
        typename LayerNormTypeConfig<DataType>::XDataType,
        typename LayerNormTypeConfig<DataType>::GammaDataType,
        typename LayerNormTypeConfig<DataType>::BetaDataType,
        typename LayerNormTypeConfig<DataType>::ComputeDataType,
        typename LayerNormTypeConfig<DataType>::YDataType,
        typename LayerNormTypeConfig<DataType>::MeanDataType,
        typename LayerNormTypeConfig<DataType>::InvStdDataType,
        typename Traits_::Shape,
        Traits_::kPadN>;

    using Pipeline = ck_tile::Layernorm2dBwdGammaBetaPipeline<PipelineProblem>;
    using Kernel   = ck_tile::Layernorm2dBwdGammaBeta<Pipeline>;

    const dim3 grids                       = Kernel::GridSize(a);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = 1;

    auto kargs = Kernel::MakeKargs(a);
    if(s.log_level_ > 0)
        std::cout << ", " << Kernel::GetName() << std::flush;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
```
example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp (new file, mode 100644)

```cpp
#include "ck_tile/host.hpp"
#include "layernorm2d_bwd.hpp"
#include <cstring>

// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("stride", "-1", "stride per row, if -1 then equal to n")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec", "fp16", "precision")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t m      = arg_parser.get_int("m");
    ck_tile::index_t n      = arg_parser.get_int("n");
    ck_tile::index_t stride = arg_parser.get_int("stride");
    if(stride < 0)
        stride = n;
    std::string data_type = arg_parser.get_str("prec");
    int kname             = arg_parser.get_int("kname");
    int do_validation     = arg_parser.get_int("v");
    int warmup            = arg_parser.get_int("warmup");
    int repeat            = arg_parser.get_int("repeat");

    assert(stride >= n);

    using TypeConfig = LayerNormTypeConfig<DataType>;

    using XDataType       = typename TypeConfig::XDataType;
    using YDataType       = typename TypeConfig::YDataType;
    using GammaDataType   = typename TypeConfig::GammaDataType;
    using BetaDataType    = typename TypeConfig::BetaDataType;
    using MeanDataType    = typename TypeConfig::MeanDataType;
    using InvStdDataType  = typename TypeConfig::InvStdDataType;
    using ComputeDataType = typename TypeConfig::ComputeDataType;

    // host verify
    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
    ck_tile::HostTensor<YDataType> dy_host({m, n}, {stride, 1});
    ck_tile::HostTensor<GammaDataType> gamma_host({n});
    ck_tile::HostTensor<MeanDataType> mean_host({m});
    ck_tile::HostTensor<InvStdDataType> invStd_host({m});

    ck_tile::index_t blockM   = layernorm2d_bwd_block_m<XDataType>();
    ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;

    ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
    ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
    ck_tile::HostTensor<XDataType> dx_host_dev({m, n});

    ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
    ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
    ck_tile::HostTensor<XDataType> dx_host_ref({m, n});

    // tmp
    ck_tile::HostTensor<ComputeDataType> ds_host_dev({m});
    ck_tile::HostTensor<ComputeDataType> db_host_dev({m});
    ck_tile::HostTensor<ComputeDataType> ds_host_ref({m});
    ck_tile::HostTensor<ComputeDataType> db_host_ref({m});

    // ck_tile::FillMonotonicSeq<YDataType>{}(dy_host);
    ck_tile::FillUniformDistribution<YDataType>{-.5f, .5f}(dy_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
    ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    // ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
    ck_tile::FillUniformDistribution<InvStdDataType>{-.5f, .5f}(invStd_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dy_buf(dy_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem mean_buf(mean_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem invStd_buf(invStd_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dgamma_buf(dgamma_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dbeta_buf(dbeta_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dx_buf(dx_host_dev.get_element_space_size_in_bytes());
    // tmp
    ck_tile::DeviceMem ds_buf(ds_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem db_buf(db_host_dev.get_element_space_size_in_bytes());

    x_buf.ToDevice(x_host.data());
    dy_buf.ToDevice(dy_host.data());
    gamma_buf.ToDevice(gamma_host.data());
    mean_buf.ToDevice(mean_host.data());
    invStd_buf.ToDevice(invStd_host.data());

    std::cout << "[" << data_type << "]"
              << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;

    layernorm2d_bwd_traits traits{data_type};

    layernorm2d_bwd_args args{x_buf.GetDeviceBuffer(),
                              dy_buf.GetDeviceBuffer(),
                              gamma_buf.GetDeviceBuffer(),
                              mean_buf.GetDeviceBuffer(),
                              invStd_buf.GetDeviceBuffer(),
                              dgamma_buf.GetDeviceBuffer(),
                              dbeta_buf.GetDeviceBuffer(),
                              dx_buf.GetDeviceBuffer(),
                              m,
                              n,
                              stride};

    float ave_time = layernorm2d_bwd(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
                           sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;

    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << sizeof(ComputeDataType) << ", " << ave_time * 1.E3 << " us, " << gb_per_sec
              << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        // reference
        ck_tile::reference_layernorm2d_bwd_gamma_part<XDataType,
                                                      GammaDataType,
                                                      BetaDataType,
                                                      ComputeDataType,
                                                      YDataType,
                                                      MeanDataType,
                                                      InvStdDataType>(x_host,
                                                                      dy_host,
                                                                      gamma_host,
                                                                      mean_host,
                                                                      invStd_host,
                                                                      dgamma_host_ref,
                                                                      dbeta_host_ref,
                                                                      dx_host_ref,
                                                                      ds_host_ref,
                                                                      db_host_ref);

        dgamma_buf.FromDevice(dgamma_host_dev.data());
        dbeta_buf.FromDevice(dbeta_host_dev.data());

        auto [rtol, atol] = get_elimit<DataType>();

        pass = ck_tile::check_err(dgamma_host_dev,
                                  dgamma_host_ref,
                                  std::string("GAMMA OUT Error: Incorrect results!"),
                                  rtol,
                                  atol);
        pass &= ck_tile::check_err(dbeta_host_dev,
                                   dbeta_host_ref,
                                   std::string("BETA OUT Error: Incorrect results!"),
                                   rtol,
                                   atol);

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    return pass;
}

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
    }
    return -3;
}
```
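The kernel and the reference both leave dgamma/dbeta as per-block partials of shape (reduce_m, n), and the validation above compares those partials directly. A final reduction over the reduce_m rows is not part of this commit; a host-side sketch of what it could look like (hypothetical, assuming the variables in run() above plus an extra #include <vector>):

```cpp
// Hypothetical final reduction (not in this commit): collapse the (reduce_m, n)
// partial tensor produced by the kernel into the full length-n gamma gradient.
std::vector<float> dgamma_final(n, 0.f);
for(int r = 0; r < reduce_m; ++r)
    for(int j = 0; j < n; ++j)
        dgamma_final[j] += ck_tile::type_convert<float>(dgamma_host_dev(r, j));
```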
example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>

template <typename DataType>
struct LayerNormTypeConfig;

template <>
struct LayerNormTypeConfig<ck_tile::half_t>
{
    using XDataType       = ck_tile::half_t;
    using YDataType       = ck_tile::half_t;
    using GammaDataType   = ck_tile::half_t;
    using BetaDataType    = ck_tile::half_t;
    using MeanDataType    = ck_tile::half_t;
    using InvStdDataType  = ck_tile::half_t;
    using ComputeDataType = float;
};

template <>
struct LayerNormTypeConfig<ck_tile::bf16_t>
{
    using XDataType       = ck_tile::bf16_t;
    using YDataType       = ck_tile::bf16_t;
    using GammaDataType   = ck_tile::bf16_t;
    using BetaDataType    = ck_tile::bf16_t;
    using MeanDataType    = ck_tile::bf16_t;
    using InvStdDataType  = ck_tile::bf16_t;
    using ComputeDataType = float;
};

// runtime args
struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
{
};

// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          bool kPadN_>
struct layernorm2d_bwd_traits_
{
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;

    // num of warps along m
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return total_warps * (warpSize / ThreadPerBlock_N_);
        }
        else
        {
            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
            return total_warps / (ThreadPerBlock_N_ / warpSize);
        }
    }();

    // num of warps along n
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return 1;
        }
        else
        {
            static_assert(ThreadPerBlock_N_ % warpSize == 0);
            return ThreadPerBlock_N_ / warpSize;
        }
    }();

    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;

    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = ThreadPerBlock_N_;

    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N;

    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, 1>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;

    static constexpr bool kPadN = kPadN_;
};

template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          bool kPadN_>
using trait_ =
    layernorm2d_bwd_traits_<DataType_, Repeat_M_, ThreadPerBlock_M_, ThreadPerBlock_N_, kPadN_>;

template <typename Traits_>
float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);

// This is the public API, will be generated by script
struct layernorm2d_bwd_traits
{
    std::string data_type;
};

template <typename data_type>
struct layernorm2d_bwd_b16_
{
    /* data */
    using Trait = trait_<data_type, 1, 1, 64, true>;

    float operator()(layernorm2d_bwd_traits /*t*/,
                     layernorm2d_bwd_args a,
                     const ck_tile::stream_config& s)
    {
        return layernorm2d_bwd_<Trait>(s, a);
    }
};

template <typename data_type>
ck_tile::index_t layernorm2d_bwd_block_m()
{
    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
};

float layernorm2d_bwd(layernorm2d_bwd_traits, layernorm2d_bwd_args, const ck_tile::stream_config&);
```
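As a worked example of the shape arithmetic (assuming warpSize == 64, as on AMD wavefronts), the single configuration instantiated in this commit resolves to:

```cpp
// trait_<ck_tile::fp16_t, 1, 1, 64, true>  (Repeat_M=1, ThreadPerBlock_M=1, ThreadPerBlock_N=64)
//   is_warp_per_row = (64 <= 64)        -> true
//   total_warps     = (1 * 64) / 64     -> 1
//   BlockWarps_M    = 1 * (64 / 64)     -> 1,   BlockWarps_N -> 1
//   Block_M = 1 * 1 -> 1,   Block_N     -> 64
//   Warp_M  = 1 / 1 -> 1,   Warp_N      = 64 / 1 -> 64
// i.e. one 64-thread wave processes a 1 x 64 tile per block.
```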
include/ck_tile/host.hpp

```diff
@@ -25,6 +25,7 @@
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"
 #include "ck_tile/host/reference/reference_moe_sorting.hpp"
 #include "ck_tile/host/reference/reference_permute.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
```
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp (new file, mode 100644)

```cpp
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType>
CK_TILE_HOST void
reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataType>& x_m_n,
                                     const HostTensor<YDataType>& dy_m_n,
                                     const HostTensor<GammaDataType>& gamma_n,
                                     const HostTensor<MeanDataType>& mean_m,
                                     const HostTensor<InvStdDataType>& inv_std_m,
                                     HostTensor<GammaDataType>& dgamma_mpart_n,
                                     HostTensor<BetaDataType>& dbeta_mpart_n,
                                     HostTensor<XDataType>& dx_m_n,
                                     // tmp
                                     HostTensor<ComputeDataType>& ds_m,
                                     HostTensor<ComputeDataType>& db_m)
{
    const auto MN   = x_m_n.mDesc.get_lengths();
    const int M     = MN[0];
    const int N     = MN[1];
    const int PartM = dgamma_mpart_n.mDesc.get_lengths()[0];
    const int MLoop = (M + PartM - 1) / PartM;

    printf("\ndteng print---M=%d,N=%d,PartM=%d,MLoop=%d\n", M, N, PartM, MLoop);

    auto f = [&](auto m) {
        const int m_offset = m * MLoop;
        // calculate dgamma, dbeta
        for(int n = 0; n < N; ++n)
        {
            ComputeDataType gamma_acc = 0;
            ComputeDataType beta_acc  = 0;
            for(int inner_m = 0; inner_m < MLoop && m_offset + inner_m < M; inner_m++)
            {
                const ComputeDataType mean =
                    ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
                const ComputeDataType inv_std =
                    ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
                const ComputeDataType x =
                    ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
                const ComputeDataType dy =
                    ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));

                gamma_acc += dy * (x - mean) * inv_std;
                beta_acc += dy;
            }
            dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
            dbeta_mpart_n(m, n)  = ck_tile::type_convert<BetaDataType>(beta_acc);
        }

        // calculate dx
        for(int inner_m = 0; inner_m < MLoop && m_offset + inner_m < M; inner_m++)
        {
            ComputeDataType ds = 0;
            ComputeDataType db = 0;
            const ComputeDataType mean =
                ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
            const ComputeDataType inv_std =
                ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
            for(int n = 0; n < N; ++n)
            {
                const ComputeDataType dy =
                    ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
                const ComputeDataType x =
                    ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
                const ComputeDataType gamma =
                    ck_tile::type_convert<ComputeDataType>(gamma_n(n));

                ds += dy * gamma * x;
                db += dy * gamma;
            }
            ComputeDataType b = (db * mean - ds) * inv_std * inv_std * inv_std / N;
            ComputeDataType c = -b * mean - db * inv_std / N;
            for(int n = 0; n < N; ++n)
            {
                const ComputeDataType dy =
                    ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
                const ComputeDataType x =
                    ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
                const ComputeDataType gamma =
                    ck_tile::type_convert<ComputeDataType>(gamma_n(n));

                dx_m_n(m_offset + inner_m, n) =
                    ck_tile::type_convert<XDataType>(dy * gamma * inv_std + b * x + c);
            }
        }
    };

    make_ParallelTensorFunctor(f, PartM)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
```
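For clarity, the math this reference implements, using the per-row mean $\mu_m$ and inverse standard deviation $\sigma^{-1}_m$ supplied as inputs, is

$$
\begin{aligned}
d\gamma_n &= \textstyle\sum_{m} dy_{m,n}\,(x_{m,n}-\mu_m)\,\sigma^{-1}_m, \qquad
d\beta_n = \textstyle\sum_{m} dy_{m,n},\\
ds_m &= \textstyle\sum_{n} dy_{m,n}\,\gamma_n\,x_{m,n}, \qquad
db_m = \textstyle\sum_{n} dy_{m,n}\,\gamma_n,\\
b_m &= \frac{(db_m\,\mu_m - ds_m)\,(\sigma^{-1}_m)^3}{N}, \qquad
c_m = -\,b_m\,\mu_m - \frac{db_m\,\sigma^{-1}_m}{N},\\
dx_{m,n} &= dy_{m,n}\,\gamma_n\,\sigma^{-1}_m + b_m\,x_{m,n} + c_m,
\end{aligned}
$$

with the $d\gamma/d\beta$ sums running only over the MLoop rows owned by one partial slot, matching the (PartM, N) layout of dgamma_mpart_n and dbeta_mpart_n.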
include/ck_tile/ops/layernorm2d.hpp

```diff
@@ -10,4 +10,10 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
+#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
```
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

// host side args
struct Layernorm2dBwdGammaBetaHostArgs
{
    const void* p_x;
    const void* p_dY;
    const void* p_gamma;
    const void* p_mean;
    const void* p_invStd;

    void* p_dGamma;
    void* p_dBeta;
    void* p_dX;
    index_t m;
    index_t n;
    index_t stride; // row_stride
};

// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dBwdGammaBeta
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = typename Pipeline::Problem;

    using XDataType       = remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM      = false; // always no need to pad along M
    static constexpr bool kPadN      = Problem::kPadN;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr auto I0                 = number<0>{};
    static constexpr auto I1                 = number<1>{};

    struct Kargs
    {
        const void* p_x;
        const void* p_dY;
        const void* p_gamma;
        const void* p_mean;
        const void* p_invStd;

        void* p_dGamma;
        void* p_dBeta;
        void* p_dX;
        index_t m;
        index_t n;
        index_t stride; // row_stride
    };
    using Hargs = Layernorm2dBwdGammaBetaHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        return Kargs{hargs.p_x,
                     hargs.p_dY,
                     hargs.p_gamma,
                     hargs.p_mean,
                     hargs.p_invStd,
                     hargs.p_dGamma,
                     hargs.p_dBeta,
                     hargs.p_dX,
                     hargs.m,
                     hargs.n,
                     hargs.stride};
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return (hargs.m + Block_M - 1) / Block_M;
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>           { static constexpr const char* name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char* name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char* name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t>  { static constexpr const char* name = "fp8";  };
    template <> struct t2s<ck_tile::bf8_t>  { static constexpr const char* name = "bf8";  };
    // clang-format on

    // in byte
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto surfix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            return n;
        }();

        #define _SS_ std::string
        #define _TS_ std::to_string
        return _SS_("layernorm2d_bwd_") + _SS_(t2s<XDataType>::name) + "_" +
            _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
            _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
            _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
            _TS_(S_::Vector_M) + "x" + _TS_(1) + "_" +
            _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto block_id = get_block_id();
        const auto iM       = block_id * Block_M;

        const auto x_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const XDataType*>(kargs.p_x),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1));
            // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
            // check the max count dynamically
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto dy_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const YDataType*>(kargs.p_dY),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1));
            // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
            // check the max count dynamically
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto gamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_gamma),
                make_tuple(kargs.n),
                make_tuple(1));
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
        }();

        const auto mean_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_mean),
                make_tuple(kargs.m),
                make_tuple(1));
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
        }();

        const auto invstd_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const MeanDataType*>(kargs.p_invStd),
                make_tuple(kargs.m),
                make_tuple(1));
            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
        }();

        auto dgamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<GammaDataType*>(kargs.p_dGamma),
                make_tuple(gridDim.x, kargs.n),
                make_tuple(kargs.n, 1));
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
        }();

        auto dbeta_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<BetaDataType*>(kargs.p_dBeta),
                make_tuple(gridDim.x, kargs.n),
                make_tuple(kargs.n, 1));
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
        }();

        auto dx_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<XDataType*>(kargs.p_dX),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1));
            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        __shared__ char smem[GetSmemSize()];

        Pipeline{}(x_window,
                   dy_window,
                   gamma_window,
                   mean_window,
                   invstd_window,
                   dgamma_window,
                   dbeta_window,
                   dx_window,
                   kargs.n,
                   smem);
    }
};
} // namespace ck_tile
```
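In short, the launch is a 1-D grid of ceil(m / Block_M) blocks; each block reduces its Block_M rows and writes a single partial row of dGamma/dBeta at row index block_id, which is why the partial outputs are viewed with shape (gridDim.x, n) rather than just (n).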
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<1, 2>,
                sequence<0, 0>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>,
                tuple<sequence<1, 0>, sequence<1, 0>>,
                tuple<sequence<1, 0>, sequence<2, 1>>,
                sequence<1>,
                sequence<0>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeDGammaBetaBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<2>,
                sequence<0>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return 1;
    }
};
} // namespace ck_tile
```
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;

    static constexpr bool kPadM = false;
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() { return "bwd_gamma_beta"; }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename XWindow,
              typename GammaWindow,
              typename MeanWindow,
              typename InvStdWindow,
              typename DGammaWindow,
              typename DBetaWindow,
              typename DXWindow>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const XWindow& dy_window_,
                                   const GammaWindow& gamma_window_,
                                   const MeanWindow& mean_window_,
                                   const InvStdWindow& inv_std_window_,
                                   DGammaWindow& dgamma_window_,
                                   DBetaWindow& dbeta_window_,
                                   DXWindow& dx_window_,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        (void)row_size;
        (void)smem;

        auto gamma_beta_dist  = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
        auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
        auto mean_dist        = Policy::template MakeMeanBlockTileDistribution<Problem>();
        auto x_dist           = Policy::template MakeXBlockTileDistribution<Problem>();

        const auto x_window     = make_tile_window(x_window_, x_dist);
        const auto dy_window    = make_tile_window(dy_window_, x_dist);
        const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
        const auto mean_window    = make_tile_window(mean_window_, mean_dist);
        const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);

        const auto x_tile       = load_tile(x_window);
        const auto dy_tile      = load_tile(dy_window);
        const auto gamma_tile   = load_tile(gamma_window);
        const auto mean_tile    = load_tile(mean_window);
        const auto inv_std_tile = load_tile(inv_std_window);

        auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
        auto dbeta_window  = make_tile_window(dbeta_window_, dgamma_beta_dist);
        auto dx_window     = make_tile_window(dx_window_, x_dist);

        auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
        auto dbeta_tile  = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
        auto dx_tile     = make_static_distributed_tensor<XDataType>(x_dist);

        auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
        auto dbeta  = cast_tile<ComputeDataType>(dbeta_tile);
        auto dx     = cast_tile<XDataType>(dx_tile);
        (void)dx_window;
        (void)dx;
        (void)gamma_tile;

        sweep_tile(x_tile, [&](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            // constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]);
            // auto& gamma = gamma_tile(gb_idx);
            // auto& beta  = beta_tile(gb_idx);

            const auto x       = type_convert<ComputeDataType>(x_tile[idx]);
            const auto dy      = type_convert<ComputeDataType>(dy_tile[idx]);
            const auto mean    = type_convert<ComputeDataType>(mean_tile[i_idx]);
            const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);

            // beta += type_convert<BetaDataType>(dy);
            // gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
            dbeta(gb_idx) += dy;
            dgamma(gb_idx) += dy * (x - mean) * inv_std;

            // index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
            // if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
            //     printf("bid %d tid %d count %d gb %f %f\n", blockIdx.x, tid, count,
            //            type_convert<float>(g), type_convert<float>(b));
            // }
        });

        store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
        store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));

        // store_tile(gamma_window, gamma_tile);
        // store_tile(beta_window, beta_tile);

        // auto ds = cast_tile<ComputeDataType>(mean_tile);
        // auto db = cast_tile<ComputeDataType>(mean_tile);
        // // calculate dx
        // sweep_tile(x_tile, [&](auto idx)) {
        //     constexpr auto i_idx = make_tuple(idx[number<0>{}]);
        //     constexpr auto j_idx = make_tuple(idx[number<1>{}]);
        //     const auto x     = type_convert<ComputeDataType>(x_tile[idx]);
        //     const auto dy    = type_convert<ComputeDataType>(dy_tile[idx]);
        //     const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
        //     // const auto mean    = type_convert<ComputeDataType>(mean_tile[i_idx]);
        //     // const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
        //     ds[i_idx] += dy * gamma * x;
        //     db[i_idx] += dy * gamma;
        // }
    }
};
} // namespace ck_tile
```
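Note that the dX path is only stubbed out here: dx_window, dx, and gamma_tile are created and then explicitly discarded with (void) casts, and the ds/db accumulation that would feed dX is left commented out, so this pipeline currently produces only the per-block partial dgamma/dbeta tiles.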
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp (new file, mode 100644)

```cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename XDataType_,
          typename GammaDataType_,
          typename BetaDataType_,
          typename ComputeDataType_,
          typename YDataType_,
          typename MeanDataType_,
          typename InvStdDataType_,
          typename BlockShape_,
          bool kPadN_>
struct Layernorm2dBwdGammaBetaPipelineProblem
{
    using XDataType       = remove_cvref_t<XDataType_>;
    using GammaDataType   = remove_cvref_t<GammaDataType_>;
    using BetaDataType    = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType       = remove_cvref_t<YDataType_>;
    using MeanDataType    = remove_cvref_t<MeanDataType_>;
    using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;

    static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile
```