Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
30e15644
Commit
30e15644
authored
Jan 22, 2025
by
AMD-dteng
Browse files
temp commit
parent
677a842e
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
466 additions
and
106 deletions
+466
-106
CMakeLists.txt
CMakeLists.txt
+1
-1
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
cmd
cmd
+5
-0
example/ck_tile/02_layernorm2d/generate.py
example/ck_tile/02_layernorm2d/generate.py
+3
-2
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp
...rm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp
+25
-3
example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp
+21
-3
example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp
example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp
+15
-5
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
+5
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
.../layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
+71
-10
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+3
-1
include/ck_tile/ops/layernorm2d/pipeline/2passtmp
include/ck_tile/ops/layernorm2d/pipeline/2passtmp
+183
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
+10
-10
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
...ernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
+116
-69
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp
+2
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+4
-0
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
CMakeLists.txt
View file @
30e15644
...
@@ -521,7 +521,7 @@ include_directories(BEFORE
...
@@ -521,7 +521,7 @@ include_directories(BEFORE
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
#
add_compile_options(-Werror)
add_compile_options
(
-Weverything
)
add_compile_options
(
-Weverything
)
endif
()
endif
()
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
30e15644
...
@@ -66,7 +66,7 @@ else()
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunreachable-code
-Wunused
-Wunused
-Wno-reserved-identifier
-Wno-reserved-identifier
-Werror
#
-Werror
-Wno-option-ignored
-Wno-option-ignored
-Wsign-compare
-Wsign-compare
-Wno-extra-semi-stmt
-Wno-extra-semi-stmt
...
...
cmd
0 → 100644
View file @
30e15644
make tile_example_layernorm2d_bwd -j 200
./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048
rocprofv2 --kernel-trace -d /home/dteng/PerfProf/out -o kernel_trace
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto -d /home/dteng/PerfProf/out
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto --mode csv -d /home/dteng/PerfProf/out
\ No newline at end of file
example/ck_tile/02_layernorm2d/generate.py
View file @
30e15644
...
@@ -84,7 +84,8 @@ struct layernorm2d_fwd_traits_
...
@@ -84,7 +84,8 @@ struct layernorm2d_fwd_traits_
if constexpr(is_warp_per_row)
if constexpr(is_warp_per_row)
{
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
//return total_warps * (warpSize / ThreadPerBlock_N_);
return total_warps;
}
}
else
else
{
{
...
@@ -483,7 +484,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -483,7 +484,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
_sweep_cond
=
't.fused_quant == {f_fused_sweep} && (t.prec_sy ==
\"
{f_sy_type}
\"
)'
.
format
(
_sweep_cond
=
't.fused_quant == {f_fused_sweep} && (t.prec_sy ==
\"
{f_sy_type}
\"
)'
.
format
(
f_fused_sweep
=
ins
.
F_kFusedQuant
,
f_sy_type
=
ins
.
F_YScaleDataType
)
f_fused_sweep
=
ins
.
F_kFusedQuant
,
f_sy_type
=
ins
.
F_YScaleDataType
)
_cond
=
'((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'
.
format
(
_cond
=
'((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'
.
format
(
f_vec_n
=
ins
.
F_Vector_N
,
f_xbias
=
ins
.
F_kXbias
,
f_fused_add
=
ins
.
F_kFusedAdd
,
f_vec_n
=
1
,
f_xbias
=
ins
.
F_kXbias
,
f_fused_add
=
ins
.
F_kFusedAdd
,
f_sweep_cond
=
_sweep_cond
)
f_sweep_cond
=
_sweep_cond
)
inner_str
+=
self
.
API_INNER_CASE
.
format
(
F_if
=
get_if_str
(
idx_in_n
,
len_in_n
,
False
),
inner_str
+=
self
.
API_INNER_CASE
.
format
(
F_if
=
get_if_str
(
idx_in_n
,
len_in_n
,
False
),
F_VEC_COND
=
_cond
,
F_instance_func
=
ins
.
call_name
)
F_VEC_COND
=
_cond
,
F_instance_func
=
ins
.
call_name
)
...
...
example/ck_tile/02_layernorm2d/instances/layernorm2d_bwd_bf16_n64_n128_instance.cpp
View file @
30e15644
...
@@ -5,7 +5,29 @@
...
@@ -5,7 +5,29 @@
#include "layernorm2d_bwd_instance_common.hpp"
#include "layernorm2d_bwd_instance_common.hpp"
// clang-format off
// clang-format off
// rm tm tn pd
// rm rn tm tn vn pd
template
float
layernorm2d_bwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
1
,
64
,
true
>
>
(
const
S
&
,
A
);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
template
float
layernorm2d_bwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
64
,
true
>
>
(
const
S
&
,
A
);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// large m
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// large n
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
template
float
layernorm2d_bwd_
<
trait_
<
ck_tile
::
bf16_t
,
1
,
4
,
1
,
128
,
8
,
true
>
>
(
const
S
&
,
A
);
template
float
layernorm2d_bwd_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
128
,
8
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/02_layernorm2d/layernorm2d_bwd.cpp
View file @
30e15644
...
@@ -126,6 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -126,6 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf
.
GetDeviceBuffer
(),
dgamma_buf
.
GetDeviceBuffer
(),
dbeta_buf
.
GetDeviceBuffer
(),
dbeta_buf
.
GetDeviceBuffer
(),
dx_buf
.
GetDeviceBuffer
(),
dx_buf
.
GetDeviceBuffer
(),
//tmp
ds_buf
.
GetDeviceBuffer
(),
db_buf
.
GetDeviceBuffer
(),
m
,
m
,
n
,
n
,
stride
};
stride
};
...
@@ -155,12 +160,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -155,12 +160,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf
.
FromDevice
(
dgamma_host_dev
.
data
());
dgamma_buf
.
FromDevice
(
dgamma_host_dev
.
data
());
dbeta_buf
.
FromDevice
(
dbeta_host_dev
.
data
());
dbeta_buf
.
FromDevice
(
dbeta_host_dev
.
data
());
dx_buf
.
FromDevice
(
dx_host_dev
.
data
());
//tmp
ds_buf
.
FromDevice
(
ds_host_dev
.
data
());
db_buf
.
FromDevice
(
db_host_dev
.
data
());
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
();
pass
=
ck_tile
::
check_err
(
// pass = ck_tile::check_err(
dgamma_host_dev
,
dgamma_host_ref
,
std
::
string
(
"GAMMA OUT Error: Incorrect results!"
),
rtol
,
atol
);
// dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
pass
&=
ck_tile
::
check_err
(
pass
&=
ck_tile
::
check_err
(
dbeta_host_dev
,
dbeta_host_ref
,
std
::
string
(
"BETA OUT Error: Incorrect results!"
),
rtol
,
atol
);
dx_host_dev
,
dx_host_ref
,
std
::
string
(
"DX OUT Error: Incorrect results!"
),
rtol
,
atol
);
//tmp
// pass &= ck_tile::check_err(
// ds_host_dev, ds_host_ref, std::string("DS OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// db_host_dev, db_host_ref, std::string("DB OUT Error: Incorrect results!"), rtol, atol);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
<<
std
::
endl
;
}
}
...
...
example/ck_tile/02_layernorm2d/layernorm2d_bwd.hpp
View file @
30e15644
...
@@ -43,8 +43,10 @@ struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
...
@@ -43,8 +43,10 @@ struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
>
bool
kPadN_
>
struct
layernorm2d_bwd_traits_
struct
layernorm2d_bwd_traits_
{
{
...
@@ -60,7 +62,8 @@ struct layernorm2d_bwd_traits_
...
@@ -60,7 +62,8 @@ struct layernorm2d_bwd_traits_
if
constexpr
(
is_warp_per_row
)
if
constexpr
(
is_warp_per_row
)
{
{
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
static_assert
(
warpSize
%
ThreadPerBlock_N_
==
0
);
return
total_warps
*
(
warpSize
/
ThreadPerBlock_N_
);
// return total_warps * (warpSize / ThreadPerBlock_N_);
return
total_warps
;
}
}
else
else
{
{
...
@@ -84,17 +87,18 @@ struct layernorm2d_bwd_traits_
...
@@ -84,17 +87,18 @@ struct layernorm2d_bwd_traits_
}();
}();
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_M
=
Repeat_M_
;
static
constexpr
ck_tile
::
index_t
Repeat_N
=
Repeat_N_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_M
=
Repeat_M_
*
ThreadPerBlock_M_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
ThreadPerBlock_N_
;
static
constexpr
ck_tile
::
index_t
Block_N
=
Repeat_N_
*
ThreadPerBlock_N_
*
Vector_N_
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_M
=
ThreadPerBlock_M_
/
BlockWarps_M
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
;
static
constexpr
ck_tile
::
index_t
Warp_N
=
ThreadPerBlock_N_
/
BlockWarps_N
*
Vector_N_
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockTile
=
ck_tile
::
sequence
<
Block_M
,
Block_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
BlockWarps
=
ck_tile
::
sequence
<
BlockWarps_M
,
BlockWarps_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
WarpTile
=
ck_tile
::
sequence
<
Warp_M
,
Warp_N
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
1
>
;
using
Vector
=
ck_tile
::
sequence
<
1
,
Vector_N_
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
using
Shape
=
ck_tile
::
Generic2dBlockShape
<
BlockTile
,
BlockWarps
,
WarpTile
,
Vector
>
;
...
@@ -103,13 +107,17 @@ struct layernorm2d_bwd_traits_
...
@@ -103,13 +107,17 @@ struct layernorm2d_bwd_traits_
template
<
typename
DataType_
,
template
<
typename
DataType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
ThreadPerBlock_N_
,
// num threads along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
>
bool
kPadN_
>
using
trait_
=
layernorm2d_bwd_traits_
<
DataType_
,
using
trait_
=
layernorm2d_bwd_traits_
<
DataType_
,
Repeat_M_
,
Repeat_M_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_M_
,
ThreadPerBlock_N_
,
ThreadPerBlock_N_
,
Vector_N_
,
kPadN_
>
;
kPadN_
>
;
template
<
typename
Traits_
>
template
<
typename
Traits_
>
...
@@ -126,7 +134,9 @@ template <typename data_type>
...
@@ -126,7 +134,9 @@ template <typename data_type>
struct
layernorm2d_bwd_b16_
struct
layernorm2d_bwd_b16_
{
{
/* data */
/* data */
using
Trait
=
trait_
<
data_type
,
1
,
1
,
64
,
true
>
;
//using Trait = trait_<data_type, 1, 1, 1, 256, 1, true>;
//using Trait = trait_<data_type, 1, 8, 64, 4, 8, true>;
using
Trait
=
trait_
<
data_type
,
1
,
4
,
1
,
128
,
8
,
true
>
;
float
operator
()
(
layernorm2d_bwd_traits
/*t*/
,
float
operator
()
(
layernorm2d_bwd_traits
/*t*/
,
layernorm2d_bwd_args
a
,
layernorm2d_bwd_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
const
ck_tile
::
stream_config
&
s
)
{
...
...
include/ck_tile/host/reference/reference_layernorm2d_bwd.hpp
View file @
30e15644
...
@@ -48,6 +48,7 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
...
@@ -48,6 +48,7 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
const
ComputeDataType
dy
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
dy_m_n
(
m_offset
+
inner_m
,
n
));
const
ComputeDataType
dy
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
dy_m_n
(
m_offset
+
inner_m
,
n
));
gamma_acc
+=
dy
*
(
x
-
mean
)
*
inv_std
;
gamma_acc
+=
dy
*
(
x
-
mean
)
*
inv_std
;
beta_acc
+=
dy
;
beta_acc
+=
dy
;
//printf("\ndteng print---dy[%d][%d]=%f\n",m_offset + inner_m,n,dy);
}
}
dgamma_mpart_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
GammaDataType
>
(
gamma_acc
);
dgamma_mpart_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
GammaDataType
>
(
gamma_acc
);
...
@@ -69,14 +70,18 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
...
@@ -69,14 +70,18 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
ds
+=
dy
*
gamma
*
x
;
ds
+=
dy
*
gamma
*
x
;
db
+=
dy
*
gamma
;
db
+=
dy
*
gamma
;
}
}
ds_m
(
m_offset
+
inner_m
)
=
ds
;
db_m
(
m_offset
+
inner_m
)
=
db
;
ComputeDataType
b
=
(
db
*
mean
-
ds
)
*
inv_std
*
inv_std
*
inv_std
/
N
;
ComputeDataType
b
=
(
db
*
mean
-
ds
)
*
inv_std
*
inv_std
*
inv_std
/
N
;
ComputeDataType
c
=
-
b
*
mean
-
db
*
inv_std
/
N
;
ComputeDataType
c
=
-
b
*
mean
-
db
*
inv_std
/
N
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
const
ComputeDataType
dy
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
dy_m_n
(
m_offset
+
inner_m
,
n
));
const
ComputeDataType
dy
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
dy_m_n
(
m_offset
+
inner_m
,
n
));
const
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m_offset
+
inner_m
,
n
));
const
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m_offset
+
inner_m
,
n
));
const
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
const
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
dx_m_n
(
m_offset
+
inner_m
,
n
)
=
ck_tile
::
type_convert
<
XDataType
>
(
dy
*
gamma
*
inv_std
+
b
*
x
+
c
);
dx_m_n
(
m_offset
+
inner_m
,
n
)
=
ck_tile
::
type_convert
<
XDataType
>
(
dy
*
gamma
*
inv_std
+
b
*
x
+
c
);
//printf("\ndteng print---dx[%d][%d]=%f\n",m_offset + inner_m,n,ck_tile::type_convert<ComputeDataType>(dx_m_n(m_offset + inner_m, n)));
}
}
}
}
};
};
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp
View file @
30e15644
...
@@ -21,6 +21,10 @@ struct Layernorm2dBwdGammaBetaHostArgs
...
@@ -21,6 +21,10 @@ struct Layernorm2dBwdGammaBetaHostArgs
void
*
p_dBeta
;
void
*
p_dBeta
;
void
*
p_dX
;
void
*
p_dX
;
//tmp
void
*
p_dS
;
void
*
p_dB
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
stride
;
// row_stride
...
@@ -43,6 +47,7 @@ struct Layernorm2dBwdGammaBeta
...
@@ -43,6 +47,7 @@ struct Layernorm2dBwdGammaBeta
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
@@ -63,6 +68,10 @@ struct Layernorm2dBwdGammaBeta
...
@@ -63,6 +68,10 @@ struct Layernorm2dBwdGammaBeta
void
*
p_dBeta
;
void
*
p_dBeta
;
void
*
p_dX
;
void
*
p_dX
;
//tmp
void
*
p_dS
;
void
*
p_dB
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
stride
;
// row_stride
...
@@ -79,6 +88,11 @@ struct Layernorm2dBwdGammaBeta
...
@@ -79,6 +88,11 @@ struct Layernorm2dBwdGammaBeta
hargs
.
p_dGamma
,
hargs
.
p_dGamma
,
hargs
.
p_dBeta
,
hargs
.
p_dBeta
,
hargs
.
p_dX
,
hargs
.
p_dX
,
//tmp
hargs
.
p_dS
,
hargs
.
p_dB
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
stride
};
...
@@ -128,11 +142,17 @@ struct Layernorm2dBwdGammaBeta
...
@@ -128,11 +142,17 @@ struct Layernorm2dBwdGammaBeta
const
auto
block_id
=
get_block_id
();
const
auto
block_id
=
get_block_id
();
const
auto
iM
=
block_id
*
Block_M
;
const
auto
iM
=
block_id
*
Block_M
;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const
auto
x_window
=
[
&
]()
{
const
auto
x_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
));
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically
// check the max count dynamically
...
@@ -146,7 +166,9 @@ struct Layernorm2dBwdGammaBeta
...
@@ -146,7 +166,9 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
YDataType
*>
(
kargs
.
p_dY
),
static_cast
<
const
YDataType
*>
(
kargs
.
p_dY
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
));
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically
// check the max count dynamically
...
@@ -160,7 +182,9 @@ struct Layernorm2dBwdGammaBeta
...
@@ -160,7 +182,9 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
MeanDataType
*>
(
kargs
.
p_gamma
),
static_cast
<
const
MeanDataType
*>
(
kargs
.
p_gamma
),
make_tuple
(
kargs
.
n
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
));
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
...
@@ -175,7 +199,7 @@ struct Layernorm2dBwdGammaBeta
...
@@ -175,7 +199,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple
(
1
));
make_tuple
(
1
));
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
}();
...
@@ -187,7 +211,7 @@ struct Layernorm2dBwdGammaBeta
...
@@ -187,7 +211,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple
(
1
));
make_tuple
(
1
));
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
}();
...
@@ -196,7 +220,9 @@ struct Layernorm2dBwdGammaBeta
...
@@ -196,7 +220,9 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
GammaDataType
*>
(
kargs
.
p_dGamma
),
static_cast
<
GammaDataType
*>
(
kargs
.
p_dGamma
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
kargs
.
n
,
1
));
make_tuple
(
kargs
.
n
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
...
@@ -208,7 +234,9 @@ struct Layernorm2dBwdGammaBeta
...
@@ -208,7 +234,9 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
BetaDataType
*>
(
kargs
.
p_dBeta
),
static_cast
<
BetaDataType
*>
(
kargs
.
p_dBeta
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
gridDim
.
x
,
kargs
.
n
),
make_tuple
(
kargs
.
n
,
1
));
make_tuple
(
kargs
.
n
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
1
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
kPadN
>
{});
...
@@ -219,14 +247,42 @@ struct Layernorm2dBwdGammaBeta
...
@@ -219,14 +247,42 @@ struct Layernorm2dBwdGammaBeta
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
XDataType
*>
(
kargs
.
p_dX
),
static_cast
<
XDataType
*>
(
kargs
.
p_dX
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
));
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
false
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
__shared__
char
smem
[
GetSmemSize
()];
//tmp
const
auto
ds_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
ComputeDataType
*>
(
kargs
.
p_dS
),
make_tuple
(
kargs
.
m
),
make_tuple
(
1
));
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
const
auto
db_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
ComputeDataType
*>
(
kargs
.
p_dB
),
make_tuple
(
kargs
.
m
),
make_tuple
(
1
));
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}();
// __shared__ char smem[GetSmemSize()];
__shared__
char
smem
[
0
];
Pipeline
{}(
x_window
,
Pipeline
{}(
x_window
,
dy_window
,
dy_window
,
...
@@ -236,6 +292,11 @@ struct Layernorm2dBwdGammaBeta
...
@@ -236,6 +292,11 @@ struct Layernorm2dBwdGammaBeta
dgamma_window
,
dgamma_window
,
dbeta_window
,
dbeta_window
,
dx_window
,
dx_window
,
//tmp
ds_window
,
db_window
,
kargs
.
n
,
kargs
.
n
,
smem
);
smem
);
}
}
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
30e15644
...
@@ -192,7 +192,9 @@ struct Layernorm2dFwd
...
@@ -192,7 +192,9 @@ struct Layernorm2dFwd
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
{
const
auto
iM
=
get_block_id
()
*
Block_M
;
const
auto
iM
=
get_block_id
()
*
Block_M
;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const
auto
x_window
=
[
&
]()
{
const
auto
x_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
...
...
include/ck_tile/ops/layernorm2d/pipeline/2passtmp
0 → 100644
View file @
30e15644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename GammaWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow,
typename DXWindow,
// tmp: scratch windows for the per-row partial sums (currently unused below)
typename DSWindow,
typename DBWindow>
// Layernorm2d backward pipeline (work-in-progress "temp commit").
// Computes dx plus per-block partial dgamma/dbeta from x, dy, gamma, mean and
// inv_std, in two passes over the row (N) dimension:
//   pass 1: accumulate per-row ds = sum_j(dy*gamma*x) and db = sum_j(dy*gamma),
//           then reduce both across the thread block;
//   pass 2: walk the tiles right-to-left computing the standard layernorm
//           backward decomposition dx = dy*gamma*inv_std + b*x + c.
// row_size is the runtime length of the N dimension; smem is unused here.
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                               const XWindow& dy_window_, // dy shares x's window type/layout
                               const GammaWindow& gamma_window_,
                               const MeanWindow& mean_window_,
                               const InvStdWindow& inv_std_window_,
                               DGammaWindow& dgamma_window_,
                               DBetaWindow& dbeta_window_,
                               DXWindow& dx_window_,
                               // tmp
                               DSWindow& ds_window_,
                               DBWindow& db_window_,
                               ck_tile::index_t row_size,
                               void* smem) const
{
    // Only the in-block sync reduce is used below, so shared memory is not needed.
    (void)smem;
    // Tile distributions for the 1D gamma/beta row, the per-row mean/inv_std
    // column, and the 2D x/dy/dx tiles.
    auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
    auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
    auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
    auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
    const auto x_window = make_tile_window(x_window_, x_dist);
    const auto dy_window = make_tile_window(dy_window_, x_dist);
    const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
    const auto mean_window = make_tile_window(mean_window_, mean_dist);
    const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
    auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
    auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
    auto dx_window = make_tile_window(dx_window_, x_dist);
    // mean/inv_std are per-row values reused by both passes; load them once.
    const auto mean_tile = load_tile(mean_window);
    const auto inv_std_tile = load_tile(inv_std_window);
    // tmp: debug outputs for ds/db are disabled for now.
    (void)ds_window_;
    (void)db_window_;
    //auto ds_window = make_tile_window(ds_window_, mean_dist);
    //auto db_window = make_tile_window(db_window_, mean_dist);
    // Per-row accumulators for pass 1 (ds = sum dy*gamma*x, db = sum dy*gamma).
    auto ds_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
    auto db_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
    clear_tile(ds_tile);
    clear_tile(db_tile);
    auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
    auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
    auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
    // Accumulate dgamma/dbeta/dx in ComputeDataType; cast back on store.
    // NOTE(review): dgamma_tile/dbeta_tile are cast before being cleared or
    // initialized — presumably cast_tile zero/value-initializes; confirm the
    // accumulators start at zero.
    auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
    auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
    auto dx = cast_tile<ComputeDataType>(dx_tile);
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    // Uniform (scalar-register) tile count across the wavefront.
    index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
    // ---- pass 1: per-thread partial ds/db over all N tiles ----
    for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
    {
        const auto x_tile = load_tile(x_window);
        const auto dy_tile = load_tile(dy_window);
        const auto gamma_tile = load_tile(gamma_window);
        move_tile_window(x_window, {0, Block_N});
        move_tile_window(dy_window, {0, Block_N});
        move_tile_window(gamma_window, {Block_N});
        sweep_tile(x_tile, [&](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            const auto x = type_convert<ComputeDataType>(x_tile[idx]);
            const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
            const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
            // Row-wise partial sums; reduced across the block after the loop.
            ds_tile(i_idx) += dy * gamma * x;
            db_tile(i_idx) += dy * gamma;
            // printf("threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x, ds_tile[i_idx]);
        });
    }
    // Reduce the per-thread partials so every lane holds the full row sums.
    auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
    block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
    block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
    // sweep_tile(x_tile, [&](auto idx) {
    //     constexpr auto i_idx = make_tuple(idx[number<0>{}]);
    //     printf("post::threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x,
    //     ds_tile[i_idx]);
    // });
    //store_tile(ds_window, ds_tile);
    //store_tile(db_window, db_tile);
    // Column offset of the start of the right-most (possibly partial) N tile;
    // the output windows are parked there so pass 2 can walk right-to-left.
    ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
    // Rewind the input windows by one tile (they sit one tile past the end
    // after pass 1) so they point at the right-most tile as well.
    move_tile_window(x_window, {0, -Block_N});
    move_tile_window(dy_window, {0, -Block_N});
    move_tile_window(gamma_window, {-Block_N});
    move_tile_window(dx_window, {0, stride_to_right_most_window});
    move_tile_window(dbeta_window, {0, stride_to_right_most_window});
    move_tile_window(dgamma_window, {0, stride_to_right_most_window});
    using XDistributedTensor = decltype(load_tile(x_window));
    constexpr auto spans = XDistributedTensor::get_distributed_spans();
    // ---- pass 2: dx plus partial dgamma/dbeta, tile by tile right-to-left ----
    // NOTE(review): x_tile/dy_tile/gamma_tile referenced below were declared
    // inside the pass-1 loop body and are out of scope here, and nothing is
    // re-loaded inside this loop even though the windows are moved each
    // iteration — this WIP loop needs per-iteration load_tile calls; confirm.
    for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
    {
        sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
            constexpr auto idx0 = make_tuple(i_idx);
            const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
            const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[idx0]);
            // Affine correction coefficients of the layernorm backward formula:
            //   dx = dy*gamma*inv_std + b*x + c
            // with b = (db*mean - ds) * inv_std^3 / N and
            //      c = -b*mean - db*inv_std / N,
            // which matches dx = inv_std*(dy*gamma - mean_j(dy*gamma)
            //                    - x_hat * mean_j(dy*gamma*x_hat)).
            auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
            auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
            sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
                constexpr auto idx = make_tuple(i_idx, j_idx);
                constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
                const auto x = type_convert<ComputeDataType>(x_tile[idx]);
                const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
                // NOTE(review): gamma_tile is indexed with the 2D idx here but
                // with the 1D j_idx in pass 1 — verify which is intended.
                const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx]);
                dbeta(gb_idx) += dy;
                dgamma(gb_idx) += dy * (x - mean) * inv_std;
                dx(idx) = dy * gamma * inv_std + b * x + c;
            });
        });
        // NOTE(review): dgamma/dbeta keep accumulating across iterations while
        // being stored every iteration, so earlier (rightward) stores hold
        // fewer terms than later ones — confirm this running-sum layout is
        // what the reduction kernel downstream expects.
        store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
        store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
        store_tile(dx_window, cast_tile<XDataType>(dx));
        // Step every window one tile to the left for the next iteration.
        move_tile_window(x_window, {0, -Block_N});
        move_tile_window(dy_window, {0, -Block_N});
        move_tile_window(gamma_window, {-Block_N});
        move_tile_window(dx_window, {0, -Block_N});
        move_tile_window(dbeta_window, {0, -Block_N});
        move_tile_window(dgamma_window, {0, -Block_N});
    }
}
};
} // namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp
View file @
30e15644
...
@@ -17,12 +17,12 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
...
@@ -17,12 +17,12 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
0
>>
{});
sequence
<
0
,
3
,
0
,
3
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeMeanBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeMeanBlockTileDistribution
()
...
@@ -32,11 +32,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
...
@@ -32,11 +32,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>
,
sequence
<
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>>
,
tuple
<
sequence
<
S
::
Repeat_M
,
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
,
S
::
Vector_M
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
2
,
1
>>
,
sequence
<
1
>
,
sequence
<
1
,
1
>
,
sequence
<
0
>>
{});
sequence
<
0
,
3
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -48,11 +48,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
...
@@ -48,11 +48,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<>
,
sequence
<>
,
tuple
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
tuple
<
sequence
<
S
::
WarpPerBlock_M
,
S
::
ThreadPerWarp_M
>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
>>
,
sequence
<
S
::
Repeat_N
,
S
::
WarpPerBlock_N
,
S
::
ThreadPerWarp_N
,
S
::
Vector_N
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
2
>
,
sequence
<
2
,
2
>
,
sequence
<
0
>>
{});
sequence
<
0
,
3
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp
View file @
30e15644
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
...
@@ -13,8 +14,9 @@ namespace ck_tile {
...
@@ -13,8 +14,9 @@ namespace ck_tile {
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dBwdGammaBetaPipelineDefaultPolicy
>
template
<
typename
Problem_
,
typename
Policy_
=
Layernorm2dBwdGammaBetaPipelineDefaultPolicy
>
struct
Layernorm2dBwdGammaBetaPipeline
struct
Layernorm2dBwdGammaBetaPipeline
{
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
ReducePolicy
=
ck_tile
::
remove_cvref_t
<
BlockReduce2dDefaultPolicy
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
GammaDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
...
@@ -24,16 +26,15 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -24,16 +26,15 @@ struct Layernorm2dBwdGammaBetaPipeline
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
static
constexpr
bool
kPadM
=
false
;
static
constexpr
bool
kPadM
=
false
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
return
"bwd_gamma_beta"
;
}();
return
"bwd_gamma_beta"
;
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
ReducePolicy
::
template
GetSmemSize
<
Problem
>();
//GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
}
}
template
<
typename
XWindow
,
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
GammaWindow
,
...
@@ -41,7 +42,11 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -41,7 +42,11 @@ struct Layernorm2dBwdGammaBetaPipeline
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
DGammaWindow
,
typename
DGammaWindow
,
typename
DBetaWindow
,
typename
DBetaWindow
,
typename
DXWindow
>
typename
DXWindow
,
// tmp
typename
DSWindow
,
typename
DBWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XWindow
&
dy_window_
,
const
XWindow
&
dy_window_
,
const
GammaWindow
&
gamma_window_
,
const
GammaWindow
&
gamma_window_
,
...
@@ -50,83 +55,125 @@ struct Layernorm2dBwdGammaBetaPipeline
...
@@ -50,83 +55,125 @@ struct Layernorm2dBwdGammaBetaPipeline
DGammaWindow
&
dgamma_window_
,
DGammaWindow
&
dgamma_window_
,
DBetaWindow
&
dbeta_window_
,
DBetaWindow
&
dbeta_window_
,
DXWindow
&
dx_window_
,
DXWindow
&
dx_window_
,
// tmp
DSWindow
&
ds_window_
,
DBWindow
&
db_window_
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
)
const
{
{
(
void
)
row_size
;
(
void
)
smem
;
auto
gamma_beta_dist
=
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>();
auto
gamma_beta_dist
=
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>();
auto
dgamma_beta_dist
=
Policy
::
template
MakeDGammaBetaBlockTileDistribution
<
Problem
>();
auto
dgamma_beta_dist
=
Policy
::
template
MakeDGammaBetaBlockTileDistribution
<
Problem
>();
auto
mean_dist
=
Policy
::
template
MakeMeanBlockTileDistribution
<
Problem
>();
auto
mean_dist
=
Policy
::
template
MakeMeanBlockTileDistribution
<
Problem
>();
auto
x_dist
=
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>();
auto
x_dist
=
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>();
const
auto
x_window
=
make_tile_window
(
x_window_
,
x_dist
);
const
auto
x_window
=
make_tile_window
(
x_window_
,
x_dist
);
const
auto
dy_window
=
make_tile_window
(
dy_window_
,
x_dist
);
const
auto
dy_window
=
make_tile_window
(
dy_window_
,
x_dist
);
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
gamma_beta_dist
);
//TO CHECK
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
gamma_beta_dist
);
//
TO CHECK
const
auto
mean_window
=
make_tile_window
(
mean_window_
,
mean_dist
);
const
auto
mean_window
=
make_tile_window
(
mean_window_
,
mean_dist
);
const
auto
inv_std_window
=
make_tile_window
(
inv_std_window_
,
mean_dist
);
const
auto
inv_std_window
=
make_tile_window
(
inv_std_window_
,
mean_dist
);
const
auto
x_tile
=
load_tile
(
x_window
);
const
auto
dy_tile
=
load_tile
(
dy_window
);
const
auto
gamma_tile
=
load_tile
(
gamma_window
);
const
auto
mean_tile
=
load_tile
(
mean_window
);
const
auto
inv_std_tile
=
load_tile
(
inv_std_window
);
auto
dgamma_window
=
make_tile_window
(
dgamma_window_
,
dgamma_beta_dist
);
auto
dgamma_window
=
make_tile_window
(
dgamma_window_
,
dgamma_beta_dist
);
auto
dbeta_window
=
make_tile_window
(
dbeta_window_
,
dgamma_beta_dist
);
auto
dbeta_window
=
make_tile_window
(
dbeta_window_
,
dgamma_beta_dist
);
auto
dx_window
=
make_tile_window
(
dx_window_
,
x_dist
);
auto
dx_window
=
make_tile_window
(
dx_window_
,
x_dist
);
auto
dgamma_tile
=
make_static_distributed_tensor
<
GammaDataType
>
(
dgamma_beta_dist
);
auto
dbeta_tile
=
make_static_distributed_tensor
<
BetaDataType
>
(
dgamma_beta_dist
);
const
auto
x_tile
=
load_tile
(
x_window
);
auto
dx_tile
=
make_static_distributed_tensor
<
XDataType
>
(
x_dist
);
const
auto
dy_tile
=
load_tile
(
dy_window
);
auto
dgamma
=
cast_tile
<
ComputeDataType
>
(
dgamma_tile
);
const
auto
gamma_tile
=
load_tile
(
gamma_window
);
auto
dbeta
=
cast_tile
<
ComputeDataType
>
(
dbeta_tile
);
const
auto
mean_tile
=
load_tile
(
mean_window
);
auto
dx
=
cast_tile
<
XDataType
>
(
dx_tile
);
const
auto
inv_std_tile
=
load_tile
(
inv_std_window
);
// tmp
auto
ds_window
=
make_tile_window
(
ds_window_
,
mean_dist
);
auto
db_window
=
make_tile_window
(
db_window_
,
mean_dist
);
auto
ds_tile
=
make_static_distributed_tensor
<
ComputeDataType
>
(
mean_dist
);
auto
db_tile
=
make_static_distributed_tensor
<
ComputeDataType
>
(
mean_dist
);
clear_tile
(
ds_tile
);
clear_tile
(
db_tile
);
// (void)ds_window;
// (void)db_window;
// auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
// auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto
dx_tile
=
make_static_distributed_tensor
<
XDataType
>
(
x_dist
);
// auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
// auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto
dx
=
cast_tile
<
ComputeDataType
>
(
dx_tile
);
(
void
)
dx_window
;
// auto gen_ones = [](ck_tile::index_t size) -> uint64_t {
(
void
)
dx
;
// if (size <= 0) return 0;
(
void
)
gamma_tile
;
// if (size >= 64) return 0xFFFFFFFFFFFFFFFF;
// return (1ULL << size) - 1;
// };
// uint64_t lane_en = gen_ones(row_size);
// printf("lane en is %lu", lane_en);
// //uint64_t lane_en = (1ULL << row_size) - 1;
// asm volatile("s_mov_b64 exec, %[s_lane_en]"
// :
// : [s_lane_en]"s"(lane_en)
// : );
sweep_tile
(
x_tile
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_tile
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
//constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
gb_idx
=
make_tuple
(
number
<
0
>
{},
idx
[
number
<
1
>
{}]);
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_tile
[
idx
]);
// auto &gamma = gamma_tile(gb_idx);
const
auto
dy
=
type_convert
<
ComputeDataType
>
(
dy_tile
[
idx
]);
// auto &beta = beta_tile(gb_idx);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_tile
[
j_idx
]);
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_tile
[
idx
]);
ds_tile
(
i_idx
)
+=
dy
*
gamma
*
x
;
const
auto
dy
=
type_convert
<
ComputeDataType
>
(
dy_tile
[
idx
]);
db_tile
(
i_idx
)
+=
dy
*
gamma
;
const
auto
mean
=
type_convert
<
ComputeDataType
>
(
mean_tile
[
i_idx
]);
// printf("db_tile pre: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x, db_tile[i_idx]);
const
auto
inv_std
=
type_convert
<
ComputeDataType
>
(
inv_std_tile
[
i_idx
]);
// printf("dy_tile: threadidx=%d, blockidx=%d, dy_tile=%f\n",threadIdx.x, blockIdx.x, dy);
// beta += type_convert<BetaDataType>(dy);
// printf("x: threadidx=%d, blockidx=%d, x_tile=%f\n",threadIdx.x, blockIdx.x, x);
// gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
// printf("gamma: threadidx=%d, blockidx=%d, gamma_tile=%f\n",threadIdx.x, blockIdx.x, gamma);
dbeta
(
gb_idx
)
+=
dy
;
dgamma
(
gb_idx
)
+=
dy
*
(
x
-
mean
)
*
inv_std
;
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
});
});
store_tile
(
dbeta_window
,
cast_tile
<
BetaDataType
>
(
dbeta
));
store_tile
(
dgamma_window
,
cast_tile
<
GammaDataType
>
(
dgamma
));
// store_tile(gamma_window, gamma_tile);
// store_tile(beta_window, beta_tile);
// auto ds = cast_tile<ComputeDataType>(mean_tile);
auto
block_reduce2d_sync
=
ReducePolicy
::
template
GetBlockReduce2dSync
<
Problem
>();
// auto db = cast_tile<ComputeDataType>(mean_tile);
auto
block_reduce2d_cross_warp_sync
=
ReducePolicy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
// //calculate dx
block_reduce2d_sync
(
ds_tile
,
ck_tile
::
ReduceOp
::
Add
{});
// sweep_tile(x_tile, [&](auto idx)) {
block_reduce2d_sync
(
db_tile
,
ck_tile
::
ReduceOp
::
Add
{});
// block_reduce2d_cross_warp_sync(ds_tile, smem, ck_tile::ReduceOp::Add{});
// block_reduce2d_cross_warp_sync(db_tile, smem, ck_tile::ReduceOp::Add{});
// sweep_tile(x_tile, [&](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// printf("db_tile post: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x,
// db_tile[i_idx]);
// });
// store_tile(ds_window, ds_tile);
// store_tile(db_window, db_tile);
// const auto x = type_convert<ComputeDataType>(x_tile[idx]);
using
XDistributedTensor
=
decltype
(
load_tile
(
x_window
));
// const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
constexpr
auto
spans
=
XDistributedTensor
::
get_distributed_spans
();
// const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
// // const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
// // const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// ds[i_idx] += dy * gamma * x;
// db[i_idx] += dy * gamma;
// }
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
i_idx
)
{
constexpr
auto
idx0
=
make_tuple
(
i_idx
);
const
auto
mean
=
type_convert
<
ComputeDataType
>
(
mean_tile
[
idx0
]);
const
auto
inv_std
=
type_convert
<
ComputeDataType
>
(
inv_std_tile
[
idx0
]);
auto
b
=
(
db_tile
[
idx0
]
*
mean
-
ds_tile
[
idx0
])
*
inv_std
*
inv_std
*
inv_std
/
row_size
;
auto
c
=
-
b
*
mean
-
db_tile
[
idx0
]
*
inv_std
/
row_size
;
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
j_idx
)
{
constexpr
auto
idx1
=
make_tuple
(
j_idx
);
constexpr
auto
idx
=
make_tuple
(
i_idx
,
j_idx
);
//constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_tile
[
idx
]);
const
auto
dy
=
type_convert
<
ComputeDataType
>
(
dy_tile
[
idx
]);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_tile
[
idx1
]);
// dbeta(gb_idx) += dy;
// dgamma(gb_idx) += dy * (x - mean) * inv_std;
dx
(
idx
)
=
dy
*
gamma
*
inv_std
+
b
*
x
+
c
;
//printf("dx: threadidx=%d, blockidx=%d, dx_tile=%f\n",threadIdx.x, blockIdx.x, dx(idx));
});
});
// store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
// store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
store_tile
(
dx_window
,
cast_tile
<
XDataType
>
(
dx
));
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp
View file @
30e15644
...
@@ -28,6 +28,8 @@ struct Layernorm2dBwdGammaBetaPipelineProblem
...
@@ -28,6 +28,8 @@ struct Layernorm2dBwdGammaBetaPipelineProblem
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
30e15644
...
@@ -133,7 +133,10 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -133,7 +133,10 @@ struct Layernorm2dFwdPipelineOnePass
{
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
// compute x = x_resi + x
//printf("x: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, x(idx));
// printf("acc pre: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
// printf("acc post: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
});
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
...
@@ -184,6 +187,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -184,6 +187,7 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
// printf("ln: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, ln_);
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
});
...
...
script/cmake-ck-dev.sh
View file @
30e15644
...
@@ -17,7 +17,7 @@ fi
...
@@ -17,7 +17,7 @@ fi
cmake
\
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm/
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm/
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
--save-temps
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment