Commit 30e15644 authored by AMD-dteng's avatar AMD-dteng
Browse files

temp commit

parent 677a842e
...@@ -521,7 +521,7 @@ include_directories(BEFORE ...@@ -521,7 +521,7 @@ include_directories(BEFORE
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
add_compile_options(-Werror) #add_compile_options(-Werror)
add_compile_options(-Weverything) add_compile_options(-Weverything)
endif() endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
......
...@@ -66,7 +66,7 @@ else() ...@@ -66,7 +66,7 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-Werror #-Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
-Wno-extra-semi-stmt -Wno-extra-semi-stmt
......
make tile_example_layernorm2d_bwd -j 200
./bin/tile_example_layernorm2d_bwd -m=2048 -n=2048
rocprofv2 --kernel-trace -d /home/dteng/PerfProf/out -o kernel_trace
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto -d /home/dteng/PerfProf/out
rocprofv2 -i /home/dteng/PerfProf/input.txt --plugin att auto --mode csv -d /home/dteng/PerfProf/out
\ No newline at end of file
...@@ -84,7 +84,8 @@ struct layernorm2d_fwd_traits_ ...@@ -84,7 +84,8 @@ struct layernorm2d_fwd_traits_
if constexpr(is_warp_per_row) if constexpr(is_warp_per_row)
{ {
static_assert(warpSize % ThreadPerBlock_N_ == 0); static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_); //return total_warps * (warpSize / ThreadPerBlock_N_);
return total_warps;
} }
else else
{ {
...@@ -483,7 +484,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -483,7 +484,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
_cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd, f_vec_n = 1, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond) f_sweep_cond = _sweep_cond)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name) F_VEC_COND = _cond, F_instance_func=ins.call_name)
......
...@@ -5,7 +5,29 @@ ...@@ -5,7 +5,29 @@
#include "layernorm2d_bwd_instance_common.hpp" #include "layernorm2d_bwd_instance_common.hpp"
// clang-format off // clang-format off
// rm tm tn pd // rm rn tm tn vn pd
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 64, true>>(const S&, A); // template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 64, true>>(const S&, A); // template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 64, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 128, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 1, true>>(const S&, A);
// large m
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 3, 8, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 32, 8, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 8, 64, 4, 8, true>>(const S&, A);
// large n
// template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
// template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
// clang-format on // clang-format on
...@@ -126,6 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -126,6 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf.GetDeviceBuffer(), dgamma_buf.GetDeviceBuffer(),
dbeta_buf.GetDeviceBuffer(), dbeta_buf.GetDeviceBuffer(),
dx_buf.GetDeviceBuffer(), dx_buf.GetDeviceBuffer(),
//tmp
ds_buf.GetDeviceBuffer(),
db_buf.GetDeviceBuffer(),
m, m,
n, n,
stride}; stride};
...@@ -155,12 +160,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -155,12 +160,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
dgamma_buf.FromDevice(dgamma_host_dev.data()); dgamma_buf.FromDevice(dgamma_host_dev.data());
dbeta_buf.FromDevice(dbeta_host_dev.data()); dbeta_buf.FromDevice(dbeta_host_dev.data());
dx_buf.FromDevice(dx_host_dev.data());
//tmp
ds_buf.FromDevice(ds_host_dev.data());
db_buf.FromDevice(db_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>(); auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err( // pass = ck_tile::check_err(
dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol); // dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err( pass &= ck_tile::check_err(
dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol); dx_host_dev, dx_host_ref, std::string("DX OUT Error: Incorrect results!"), rtol, atol);
//tmp
// pass &= ck_tile::check_err(
// ds_host_dev, ds_host_ref, std::string("DS OUT Error: Incorrect results!"), rtol, atol);
// pass &= ck_tile::check_err(
// db_host_dev, db_host_ref, std::string("DB OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
} }
......
...@@ -43,8 +43,10 @@ struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs ...@@ -43,8 +43,10 @@ struct layernorm2d_bwd_args : public ck_tile::Layernorm2dBwdGammaBetaHostArgs
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_, template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_> bool kPadN_>
struct layernorm2d_bwd_traits_ struct layernorm2d_bwd_traits_
{ {
...@@ -60,7 +62,8 @@ struct layernorm2d_bwd_traits_ ...@@ -60,7 +62,8 @@ struct layernorm2d_bwd_traits_
if constexpr(is_warp_per_row) if constexpr(is_warp_per_row)
{ {
static_assert(warpSize % ThreadPerBlock_N_ == 0); static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_); // return total_warps * (warpSize / ThreadPerBlock_N_);
return total_warps;
} }
else else
{ {
...@@ -84,17 +87,18 @@ struct layernorm2d_bwd_traits_ ...@@ -84,17 +87,18 @@ struct layernorm2d_bwd_traits_
}(); }();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = ThreadPerBlock_N_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N; static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>; using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>; using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>; using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, 1>; using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>; using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
...@@ -103,13 +107,17 @@ struct layernorm2d_bwd_traits_ ...@@ -103,13 +107,17 @@ struct layernorm2d_bwd_traits_
template <typename DataType_, template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_> bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_, using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_, Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_, ThreadPerBlock_M_,
ThreadPerBlock_N_, ThreadPerBlock_N_,
Vector_N_,
kPadN_>; kPadN_>;
template <typename Traits_> template <typename Traits_>
...@@ -126,7 +134,9 @@ template <typename data_type> ...@@ -126,7 +134,9 @@ template <typename data_type>
struct layernorm2d_bwd_b16_ struct layernorm2d_bwd_b16_
{ {
/* data */ /* data */
using Trait = trait_<data_type, 1, 1, 64, true>; //using Trait = trait_<data_type, 1, 1, 1, 256, 1, true>;
//using Trait = trait_<data_type, 1, 8, 64, 4, 8, true>;
using Trait = trait_<data_type, 1, 4, 1, 128, 8, true>;
float operator() (layernorm2d_bwd_traits /*t*/, float operator() (layernorm2d_bwd_traits /*t*/,
layernorm2d_bwd_args a, layernorm2d_bwd_args a,
const ck_tile::stream_config& s) { const ck_tile::stream_config& s) {
......
...@@ -48,6 +48,7 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp ...@@ -48,6 +48,7 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n)); const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
gamma_acc += dy * (x - mean) * inv_std; gamma_acc += dy * (x - mean) * inv_std;
beta_acc += dy; beta_acc += dy;
//printf("\ndteng print---dy[%d][%d]=%f\n",m_offset + inner_m,n,dy);
} }
dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc); dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
...@@ -69,14 +70,18 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp ...@@ -69,14 +70,18 @@ CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataTyp
ds += dy * gamma * x; ds += dy * gamma * x;
db += dy * gamma; db += dy * gamma;
} }
ds_m(m_offset + inner_m) = ds;
db_m(m_offset + inner_m) = db;
ComputeDataType b = (db * mean - ds) * inv_std * inv_std * inv_std / N; ComputeDataType b = (db * mean - ds) * inv_std * inv_std * inv_std / N;
ComputeDataType c = -b * mean - db * inv_std / N; ComputeDataType c = -b * mean - db * inv_std / N;
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n)); const ComputeDataType dy = ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n)); const ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
const ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n)); const ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
dx_m_n(m_offset + inner_m, n) = ck_tile::type_convert<XDataType>(dy * gamma * inv_std + b * x + c); dx_m_n(m_offset + inner_m, n) = ck_tile::type_convert<XDataType>(dy * gamma * inv_std + b * x + c);
//printf("\ndteng print---dx[%d][%d]=%f\n",m_offset + inner_m,n,ck_tile::type_convert<ComputeDataType>(dx_m_n(m_offset + inner_m, n)));
} }
} }
}; };
......
...@@ -21,6 +21,10 @@ struct Layernorm2dBwdGammaBetaHostArgs ...@@ -21,6 +21,10 @@ struct Layernorm2dBwdGammaBetaHostArgs
void* p_dBeta; void* p_dBeta;
void* p_dX; void* p_dX;
//tmp
void* p_dS;
void* p_dB;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t stride; // row_stride
...@@ -43,6 +47,7 @@ struct Layernorm2dBwdGammaBeta ...@@ -43,6 +47,7 @@ struct Layernorm2dBwdGammaBeta
static constexpr index_t Block_M = Problem::BlockShape::Block_M; static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N; static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr bool kPadM = false; // always no need to pad along M static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -63,6 +68,10 @@ struct Layernorm2dBwdGammaBeta ...@@ -63,6 +68,10 @@ struct Layernorm2dBwdGammaBeta
void* p_dBeta; void* p_dBeta;
void* p_dX; void* p_dX;
//tmp
void* p_dS;
void* p_dB;
index_t m; index_t m;
index_t n; index_t n;
index_t stride; // row_stride index_t stride; // row_stride
...@@ -79,6 +88,11 @@ struct Layernorm2dBwdGammaBeta ...@@ -79,6 +88,11 @@ struct Layernorm2dBwdGammaBeta
hargs.p_dGamma, hargs.p_dGamma,
hargs.p_dBeta, hargs.p_dBeta,
hargs.p_dX, hargs.p_dX,
//tmp
hargs.p_dS,
hargs.p_dB,
hargs.m, hargs.m,
hargs.n, hargs.n,
hargs.stride}; hargs.stride};
...@@ -128,11 +142,17 @@ struct Layernorm2dBwdGammaBeta ...@@ -128,11 +142,17 @@ struct Layernorm2dBwdGammaBeta
const auto block_id = get_block_id(); const auto block_id = get_block_id();
const auto iM = block_id * Block_M; const auto iM = block_id * Block_M;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const auto x_window = [&]() { const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1)); make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically // check the max count dynamically
...@@ -146,7 +166,9 @@ struct Layernorm2dBwdGammaBeta ...@@ -146,7 +166,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const YDataType*>(kargs.p_dY), static_cast<const YDataType*>(kargs.p_dY),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1)); make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically // check the max count dynamically
...@@ -160,7 +182,9 @@ struct Layernorm2dBwdGammaBeta ...@@ -160,7 +182,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const MeanDataType*>(kargs.p_gamma), static_cast<const MeanDataType*>(kargs.p_gamma),
make_tuple(kargs.n), make_tuple(kargs.n),
make_tuple(1)); make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{}); pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
...@@ -175,7 +199,7 @@ struct Layernorm2dBwdGammaBeta ...@@ -175,7 +199,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple(1)); make_tuple(1));
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{}); pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM}); return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}(); }();
...@@ -187,7 +211,7 @@ struct Layernorm2dBwdGammaBeta ...@@ -187,7 +211,7 @@ struct Layernorm2dBwdGammaBeta
make_tuple(1)); make_tuple(1));
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{}); pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM}); return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}(); }();
...@@ -196,7 +220,9 @@ struct Layernorm2dBwdGammaBeta ...@@ -196,7 +220,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<GammaDataType*>(kargs.p_dGamma), static_cast<GammaDataType*>(kargs.p_dGamma),
make_tuple(gridDim.x, kargs.n), make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1)); make_tuple(kargs.n, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{}); pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
...@@ -208,7 +234,9 @@ struct Layernorm2dBwdGammaBeta ...@@ -208,7 +234,9 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<BetaDataType*>(kargs.p_dBeta), static_cast<BetaDataType*>(kargs.p_dBeta),
make_tuple(gridDim.x, kargs.n), make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1)); make_tuple(kargs.n, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{}); pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
...@@ -219,14 +247,42 @@ struct Layernorm2dBwdGammaBeta ...@@ -219,14 +247,42 @@ struct Layernorm2dBwdGammaBeta
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<XDataType*>(kargs.p_dX), static_cast<XDataType*>(kargs.p_dX),
make_tuple(kargs.m, kargs.n), make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1)); make_tuple(kargs.stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{}); pad_tensor_view(tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0}); return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}(); }();
__shared__ char smem[GetSmemSize()]; //tmp
const auto ds_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<ComputeDataType*>(kargs.p_dS),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
const auto db_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<ComputeDataType*>(kargs.p_dB),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
// __shared__ char smem[GetSmemSize()];
__shared__ char smem[0];
Pipeline{}(x_window, Pipeline{}(x_window,
dy_window, dy_window,
...@@ -236,6 +292,11 @@ struct Layernorm2dBwdGammaBeta ...@@ -236,6 +292,11 @@ struct Layernorm2dBwdGammaBeta
dgamma_window, dgamma_window,
dbeta_window, dbeta_window,
dx_window, dx_window,
//tmp
ds_window,
db_window,
kargs.n, kargs.n,
smem); smem);
} }
......
...@@ -192,7 +192,9 @@ struct Layernorm2dFwd ...@@ -192,7 +192,9 @@ struct Layernorm2dFwd
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
const auto iM = get_block_id() * Block_M; const auto iM = get_block_id() * Block_M;
// if(threadIdx.x == 0 && blockIdx.x == 0){
// printf("dteng block shape---WarpPerBlock_M=%d, WarpPerBlock_N=%d, ThreadPerWarp_M=%d, ThreadPerWarp_N=%d, Vector_N=%d\n", static_cast<int>(Problem::BlockShape::WarpPerBlock_M), static_cast<int>(Problem::BlockShape::WarpPerBlock_N), static_cast<int>(Problem::BlockShape::ThreadPerWarp_M), static_cast<int>(Problem::BlockShape::ThreadPerWarp_N), static_cast<int>(Problem::BlockShape::Vector_N));
// }
const auto x_window = [&]() { const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>( const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x), static_cast<const XDataType*>(kargs.p_x),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
// Block-level pipeline computing the LayerNorm-2D backward pass for one block of M rows:
//   dgamma[n] = sum_m dy * (x - mean) * inv_std       (per-block partial, finalized elsewhere)
//   dbeta[n]  = sum_m dy
//   dx[m][n]  = dy * gamma * inv_std + b * x + c       with b, c derived from row sums ds, db
// Strategy (two sweeps over N tiles of width Block_N):
//   1. forward sweep: accumulate per-row ds = sum(dy*gamma*x) and db = sum(dy*gamma),
//      then cross-lane reduce them so every thread of a row holds the full row sums;
//   2. backward sweep (right-to-left): recompute per-element dx and the per-tile
//      dgamma/dbeta partials and store them tile by tile.
// NOTE(review): this file comes from a "temp commit" — several WIP issues are flagged inline.
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
// Reduction policy is hard-wired to the default block-reduce policy, not taken from Policy_.
using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
// M is never padded by this pipeline; N padding follows the problem definition.
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
// Shared-memory requirement, delegated to the policy.
// NOTE(review): the kernel side currently passes a zero-sized smem buffer
// (`__shared__ char smem[0]`) and operator() voids `smem` — confirm the policy
// reports 0 here, otherwise the kernel under-allocates.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow,
typename GammaWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow,
typename DXWindow,
// tmp
typename DSWindow,
typename DBWindow>
// Runs the two-sweep backward computation for this block.
// @param x_window_/dy_window_   2-D input windows over x and dy (same distribution)
// @param gamma_window_          1-D window over gamma along N
// @param mean_window_/inv_std_window_ 1-D per-row statistics along M
// @param dgamma_window_/dbeta_window_ per-block partial outputs (one row per block)
// @param dx_window_             2-D output window for dx
// @param ds_window_/db_window_  temporary debug outputs, currently unused (voided below)
// @param row_size               logical N extent of one row
// @param smem                   shared-memory scratch, currently unused
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_,
const GammaWindow& gamma_window_,
const MeanWindow& mean_window_,
const InvStdWindow& inv_std_window_,
DGammaWindow& dgamma_window_,
DBetaWindow& dbeta_window_,
DXWindow& dx_window_,
// tmp
DSWindow& ds_window_,
DBWindow& db_window_,
ck_tile::index_t row_size,
void* smem) const
{
(void)smem;
// Tile distributions: x/dy/dx share one 2-D distribution; mean/inv_std/ds/db share
// a per-row (1-D along M) distribution; gamma and dgamma/dbeta are 1-D along N.
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
const auto x_window = make_tile_window(x_window_, x_dist);
const auto dy_window = make_tile_window(dy_window_, x_dist);
const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist);
// Per-row statistics are loaded once; they are constant across the N sweep.
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
// tmp — debug outputs for ds/db are wired through but disabled in this commit.
(void)ds_window_;
(void)db_window_;
//auto ds_window = make_tile_window(ds_window_, mean_dist);
//auto db_window = make_tile_window(db_window_, mean_dist);
// Per-row accumulators: ds = sum_n(dy*gamma*x), db = sum_n(dy*gamma).
auto ds_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
auto db_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
clear_tile(ds_tile);
clear_tile(db_tile);
auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
// Compute-precision working copies.
// NOTE(review): dgamma_tile/dbeta_tile/dx_tile are never initialized before being
// cast here, and dgamma/dbeta are later accumulated with `+=` without a clear_tile —
// the accumulators start from indeterminate values; confirm a clear is intended.
auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<ComputeDataType>(dx_tile);
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
// readfirstlane keeps the trip count in a scalar register (uniform across the wave).
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
// Sweep 1 (left-to-right): accumulate the per-row sums ds and db tile by tile.
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
move_tile_window(x_window, {0, Block_N});
move_tile_window(dy_window, {0, Block_N});
move_tile_window(gamma_window, {Block_N});
sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
ds_tile(i_idx) += dy * gamma * x;
db_tile(i_idx) += dy * gamma;
// printf("threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x, ds_tile[i_idx]);
});
}
// Cross-lane reduction so every thread of a row sees the complete ds/db row sums.
auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
// sweep_tile(x_tile, [&](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// printf("post::threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x,
// ds_tile[i_idx]);
// });
//store_tile(ds_window, ds_tile);
//store_tile(db_window, db_tile);
// Offset of the right-most N tile (handles a ragged last tile when row_size % Block_N != 0).
ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
// Reposition: inputs step back one tile from their post-sweep position; outputs jump
// to the right-most tile. The second sweep then walks right-to-left.
move_tile_window(x_window, {0, -Block_N});
move_tile_window(dy_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(dx_window, {0, stride_to_right_most_window});
move_tile_window(dbeta_window, {0, stride_to_right_most_window});
move_tile_window(dgamma_window, {0, stride_to_right_most_window});
using XDistributedTensor = decltype(load_tile(x_window));
constexpr auto spans = XDistributedTensor::get_distributed_spans();
// Sweep 2 (right-to-left): compute dx and per-tile dgamma/dbeta partials, store each tile.
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
constexpr auto idx0 = make_tuple(i_idx);
const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[idx0]);
// b and c fold the row sums into the standard LayerNorm backward coefficients:
// dx = dy*gamma*inv_std + b*x + c.
auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
constexpr auto idx = make_tuple(i_idx, j_idx);
constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
// NOTE(review): x_tile/dy_tile/gamma_tile below were declared inside the FIRST
// loop's body and are out of scope here — this cannot compile as written.
// Presumably the tiles must be reloaded from the (now backward-moving) windows
// at the top of this loop; confirm against the author's intent ("temp commit").
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
// NOTE(review): sweep 1 indexes gamma_tile with j_idx (1-D); here the full 2-D
// idx is used — verify which indexing the gamma distribution expects.
const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx]);
// NOTE(review): `+=` carries dgamma/dbeta across N tiles, but each tile covers a
// DIFFERENT set of gamma/beta columns and is stored separately below — confirm
// whether a per-tile reset is intended.
dbeta(gb_idx) += dy;
dgamma(gb_idx) += dy * (x - mean) * inv_std;
dx(idx) = dy * gamma * inv_std + b * x + c;
});
});
store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
store_tile(dx_window, cast_tile<XDataType>(dx));
// Step every window one tile to the left for the next iteration.
move_tile_window(x_window, {0, -Block_N});
move_tile_window(dy_window, {0, -Block_N});
move_tile_window(gamma_window, {-Block_N});
move_tile_window(dx_window, {0, -Block_N});
move_tile_window(dbeta_window, {0, -Block_N});
move_tile_window(dgamma_window, {0, -Block_N});
}
}
};
} // namespace ck_tile
...@@ -17,12 +17,12 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy ...@@ -17,12 +17,12 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>, tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>, sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>, tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<2, 2>>, tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 2>, sequence<1, 1, 2, 2>,
sequence<0, 0>>{}); sequence<0, 3, 0, 3>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution() CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
...@@ -32,11 +32,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy ...@@ -32,11 +32,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>, sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>, tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>>,
tuple<sequence<1, 0>, sequence<1, 0>>, tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<2, 1>>, tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>, sequence<1, 1>,
sequence<0>>{}); sequence<0, 3>>{});
} }
template <typename Problem> template <typename Problem>
...@@ -48,11 +48,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy ...@@ -48,11 +48,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
tile_distribution_encoding< tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>, tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>, sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>, tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>, tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2>, sequence<2, 2>,
sequence<0>>{}); sequence<0, 3>>{});
} }
template <typename Problem> template <typename Problem>
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp"
#include <string> #include <string>
#include <type_traits> #include <type_traits>
...@@ -13,8 +14,9 @@ namespace ck_tile { ...@@ -13,8 +14,9 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy> template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline struct Layernorm2dBwdGammaBetaPipeline
{ {
using Problem = ck_tile::remove_cvref_t<Problem_>; using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>; using Policy = ck_tile::remove_cvref_t<Policy_>;
using ReducePolicy = ck_tile::remove_cvref_t<BlockReduce2dDefaultPolicy>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>; using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>; using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
...@@ -24,16 +26,15 @@ struct Layernorm2dBwdGammaBetaPipeline ...@@ -24,16 +26,15 @@ struct Layernorm2dBwdGammaBetaPipeline
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>; using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>; using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false; static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() { static constexpr const char* name = []() { return "bwd_gamma_beta"; }();
return "bwd_gamma_beta";
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return ReducePolicy::template GetSmemSize<Problem>();
//GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
} }
template <typename XWindow, template <typename XWindow,
typename GammaWindow, typename GammaWindow,
...@@ -41,7 +42,11 @@ struct Layernorm2dBwdGammaBetaPipeline ...@@ -41,7 +42,11 @@ struct Layernorm2dBwdGammaBetaPipeline
typename InvStdWindow, typename InvStdWindow,
typename DGammaWindow, typename DGammaWindow,
typename DBetaWindow, typename DBetaWindow,
typename DXWindow> typename DXWindow,
// tmp
typename DSWindow,
typename DBWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_, CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_, const XWindow& dy_window_,
const GammaWindow& gamma_window_, const GammaWindow& gamma_window_,
...@@ -50,83 +55,125 @@ struct Layernorm2dBwdGammaBetaPipeline ...@@ -50,83 +55,125 @@ struct Layernorm2dBwdGammaBetaPipeline
DGammaWindow& dgamma_window_, DGammaWindow& dgamma_window_,
DBetaWindow& dbeta_window_, DBetaWindow& dbeta_window_,
DXWindow& dx_window_, DXWindow& dx_window_,
// tmp
DSWindow& ds_window_,
DBWindow& db_window_,
ck_tile::index_t row_size, ck_tile::index_t row_size,
void* smem) const void* smem) const
{ {
(void)row_size;
(void)smem;
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>(); auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>(); auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>(); auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>(); auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
const auto x_window = make_tile_window(x_window_, x_dist); const auto x_window = make_tile_window(x_window_, x_dist);
const auto dy_window = make_tile_window(dy_window_, x_dist); const auto dy_window = make_tile_window(dy_window_, x_dist);
const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); //TO CHECK const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
const auto mean_window = make_tile_window(mean_window_, mean_dist); const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist); const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto gamma_tile = load_tile(gamma_window);
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist); auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist); auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
auto dx_window = make_tile_window(dx_window_, x_dist); auto dx_window = make_tile_window(dx_window_, x_dist);
auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist); const auto x_tile = load_tile(x_window);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist); const auto dy_tile = load_tile(dy_window);
auto dgamma = cast_tile<ComputeDataType>(dgamma_tile); const auto gamma_tile = load_tile(gamma_window);
auto dbeta = cast_tile<ComputeDataType>(dbeta_tile); const auto mean_tile = load_tile(mean_window);
auto dx = cast_tile<XDataType>(dx_tile); const auto inv_std_tile = load_tile(inv_std_window);
// tmp
auto ds_window = make_tile_window(ds_window_, mean_dist);
auto db_window = make_tile_window(db_window_, mean_dist);
auto ds_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
auto db_tile = make_static_distributed_tensor<ComputeDataType>(mean_dist);
clear_tile(ds_tile);
clear_tile(db_tile);
// (void)ds_window;
// (void)db_window;
// auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
// auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
// auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
// auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
auto dx = cast_tile<ComputeDataType>(dx_tile);
(void)dx_window; // auto gen_ones = [](ck_tile::index_t size) -> uint64_t {
(void)dx; // if (size <= 0) return 0;
(void)gamma_tile; // if (size >= 64) return 0xFFFFFFFFFFFFFFFF;
// return (1ULL << size) - 1;
// };
// uint64_t lane_en = gen_ones(row_size);
// printf("lane en is %lu", lane_en);
// //uint64_t lane_en = (1ULL << row_size) - 1;
// asm volatile("s_mov_b64 exec, %[s_lane_en]"
// :
// : [s_lane_en]"s"(lane_en)
// : );
sweep_tile(x_tile, [&](auto idx) { sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto i_idx = make_tuple(idx[number<0>{}]);
//constexpr auto j_idx = make_tuple(idx[number<1>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]); const auto x = type_convert<ComputeDataType>(x_tile[idx]);
// auto &gamma = gamma_tile(gb_idx); const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
// auto &beta = beta_tile(gb_idx); const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
const auto x = type_convert<ComputeDataType>(x_tile[idx]); ds_tile(i_idx) += dy * gamma * x;
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]); db_tile(i_idx) += dy * gamma;
const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]); // printf("db_tile pre: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x, db_tile[i_idx]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]); // printf("dy_tile: threadidx=%d, blockidx=%d, dy_tile=%f\n",threadIdx.x, blockIdx.x, dy);
// beta += type_convert<BetaDataType>(dy); // printf("x: threadidx=%d, blockidx=%d, x_tile=%f\n",threadIdx.x, blockIdx.x, x);
// gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std); // printf("gamma: threadidx=%d, blockidx=%d, gamma_tile=%f\n",threadIdx.x, blockIdx.x, gamma);
dbeta(gb_idx) += dy;
dgamma(gb_idx) += dy * (x - mean) * inv_std;
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
}); });
store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
// store_tile(gamma_window, gamma_tile);
// store_tile(beta_window, beta_tile);
// auto ds = cast_tile<ComputeDataType>(mean_tile); auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
// auto db = cast_tile<ComputeDataType>(mean_tile); auto block_reduce2d_cross_warp_sync = ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>();
// //calculate dx block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
// sweep_tile(x_tile, [&](auto idx)) { block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
// block_reduce2d_cross_warp_sync(ds_tile, smem, ck_tile::ReduceOp::Add{});
// block_reduce2d_cross_warp_sync(db_tile, smem, ck_tile::ReduceOp::Add{});
// sweep_tile(x_tile, [&](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]); // constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]); // printf("db_tile post: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x,
// db_tile[i_idx]);
// });
// store_tile(ds_window, ds_tile);
// store_tile(db_window, db_tile);
// const auto x = type_convert<ComputeDataType>(x_tile[idx]); using XDistributedTensor = decltype(load_tile(x_window));
// const auto dy = type_convert<ComputeDataType>(dy_tile[idx]); constexpr auto spans = XDistributedTensor::get_distributed_spans();
// const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
// // const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
// // const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
// ds[i_idx] += dy * gamma * x;
// db[i_idx] += dy * gamma;
// }
sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
constexpr auto idx0 = make_tuple(i_idx);
const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[idx0]);
auto b = (db_tile[idx0] * mean - ds_tile[idx0]) * inv_std * inv_std * inv_std / row_size;
auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
constexpr auto idx1 = make_tuple(j_idx);
constexpr auto idx = make_tuple(i_idx, j_idx);
//constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx1]);
// dbeta(gb_idx) += dy;
// dgamma(gb_idx) += dy * (x - mean) * inv_std;
dx(idx) = dy * gamma * inv_std + b * x + c;
//printf("dx: threadidx=%d, blockidx=%d, dx_tile=%f\n",threadIdx.x, blockIdx.x, dx(idx));
});
});
// store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
// store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
store_tile(dx_window, cast_tile<XDataType>(dx));
} }
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -28,6 +28,8 @@ struct Layernorm2dBwdGammaBetaPipelineProblem ...@@ -28,6 +28,8 @@ struct Layernorm2dBwdGammaBetaPipelineProblem
using BlockShape = remove_cvref_t<BlockShape_>; using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -133,7 +133,10 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -133,7 +133,10 @@ struct Layernorm2dFwdPipelineOnePass
{ {
sweep_tile(x_resi, [&](auto idx) { sweep_tile(x_resi, [&](auto idx) {
// compute x = x_resi + x // compute x = x_resi + x
//printf("x: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, x(idx));
// printf("acc pre: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx); acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
// printf("acc post: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, acc(idx));
}); });
if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
store_tile(y_residual_window, cast_tile<YResidualDataType>(acc)); store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
...@@ -184,6 +187,7 @@ struct Layernorm2dFwdPipelineOnePass ...@@ -184,6 +187,7 @@ struct Layernorm2dFwdPipelineOnePass
const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]); const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
// printf("ln: threadidx=%d, blockidx=%d, acc=%f\n",threadIdx.x, blockIdx.x, ln_);
ln(idx) = ln_; ln(idx) = ln_;
}); });
......
...@@ -17,7 +17,7 @@ fi ...@@ -17,7 +17,7 @@ fi
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker --save-temps" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment