Commit b0b399d9 authored by AMD-dteng

optimize for dgrad

parent 30e15644
@@ -28,6 +28,10 @@
 // large n
 // template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
 // template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 32, 4, 16, 8, true>>(const S&, A);
-template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
-template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 128, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 64, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 64, 8, true>>(const S&, A);
+// two pass
+template float layernorm2d_bwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 32, 8, true>>(const S&, A);
+template float layernorm2d_bwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 32, 8, true>>(const S&, A);
 // clang-format on
@@ -27,7 +27,7 @@ float layernorm2d_bwd_(const S& s, A a)
 typename Traits_::Shape,
 Traits_::kPadN>;
-using Pipeline = ck_tile::Layernorm2dBwdGammaBetaPipeline<PipelineProblem>;
+using Pipeline = ck_tile::Layernorm2dBwdGammaBetaPipelineTwoPass<PipelineProblem>;
 using Kernel = ck_tile::Layernorm2dBwdGammaBeta<Pipeline>;
...
@@ -139,7 +139,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
 traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
 std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(GammaDataType) * n +
-    sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
+    sizeof(MeanDataType) * m + sizeof(InvStdDataType) * m + sizeof(YDataType) * m * n + sizeof(XDataType);
 float gb_per_sec = num_byte / 1.E6 / ave_time;
 std::cout << sizeof(ComputeDataType) << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...
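Note on the bandwidth figure in the hunk above: assuming ave_time is in milliseconds (the print scales it by 1.E3 to report microseconds), dividing bytes by 10^6 and then by milliseconds yields GB/s directly:

\[
\text{GB/s} = \frac{\text{num\_byte}/10^{9}\ \text{GB}}{\text{ave\_time}/10^{3}\ \text{s}} = \frac{\text{num\_byte}}{10^{6}\cdot\text{ave\_time}}
\]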
@@ -136,7 +136,7 @@ struct layernorm2d_bwd_b16_
 /* data */
 //using Trait = trait_<data_type, 1, 1, 1, 256, 1, true>;
 //using Trait = trait_<data_type, 1, 8, 64, 4, 8, true>;
-using Trait = trait_<data_type, 1, 4, 1, 128, 8, true>;
+using Trait = trait_<data_type, 1, 4, 1, 64, 8, true>;
 float operator() (layernorm2d_bwd_traits /*t*/,
 layernorm2d_bwd_args a,
 const ck_tile::stream_config& s) {
...
@@ -14,6 +14,7 @@
 #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_bwd_gamma_beta_kernel.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_gamma_beta.hpp"
+#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_two_pass_dx.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_bwd_pipeline_problem.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
@@ -12,7 +12,7 @@
 namespace ck_tile {
 template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
-struct Layernorm2dBwdGammaBetaPipeline
+struct Layernorm2dBwdGammaBetaPipelineTwoPass
 {
 using Problem = ck_tile::remove_cvref_t<Problem_>;
 using Policy = ck_tile::remove_cvref_t<Policy_>;
@@ -33,7 +33,8 @@ struct Layernorm2dBwdGammaBetaPipeline
 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
 {
-return Policy::template GetSmemSize<Problem>();
+return ReducePolicy::template GetSmemSize<Problem>();
+//GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
 }
 template <typename XWindow,
 typename GammaWindow,
@@ -62,18 +63,17 @@ struct Layernorm2dBwdGammaBetaPipeline
 ck_tile::index_t row_size,
 void* smem) const
 {
-(void)smem;
 auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
 auto dgamma_beta_dist = Policy::template MakeDGammaBetaBlockTileDistribution<Problem>();
 auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
 auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
-const auto x_window = make_tile_window(x_window_, x_dist);
-const auto dy_window = make_tile_window(dy_window_, x_dist);
-const auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
-const auto mean_window = make_tile_window(mean_window_, mean_dist);
-const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
+auto x_window = make_tile_window(x_window_, x_dist);
+auto dy_window = make_tile_window(dy_window_, x_dist);
+auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist); // TO CHECK
+auto mean_window = make_tile_window(mean_window_, mean_dist);
+auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
 auto dgamma_window = make_tile_window(dgamma_window_, dgamma_beta_dist);
 auto dbeta_window = make_tile_window(dbeta_window_, dgamma_beta_dist);
@@ -92,11 +92,11 @@ struct Layernorm2dBwdGammaBetaPipeline
 clear_tile(ds_tile);
 clear_tile(db_tile);
-auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
-auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
+// auto dgamma_tile = make_static_distributed_tensor<GammaDataType>(dgamma_beta_dist);
+// auto dbeta_tile = make_static_distributed_tensor<BetaDataType>(dgamma_beta_dist);
 auto dx_tile = make_static_distributed_tensor<XDataType>(x_dist);
-auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
-auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
+// auto dgamma = cast_tile<ComputeDataType>(dgamma_tile);
+// auto dbeta = cast_tile<ComputeDataType>(dbeta_tile);
 auto dx = cast_tile<ComputeDataType>(dx_tile);
 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
@@ -119,36 +119,48 @@ struct Layernorm2dBwdGammaBetaPipeline
 const auto gamma = type_convert<ComputeDataType>(gamma_tile[j_idx]);
 ds_tile(i_idx) += dy * gamma * x;
 db_tile(i_idx) += dy * gamma;
-// printf("threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x, ds_tile[i_idx]);
+// dx(idx) = dy * gamma;
+// ds_tile(i_idx) += dx[idx] * x;
+// db_tile(i_idx) += dx[idx];
+// printf("db_tile pre: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x, db_tile[i_idx]);
+// printf("dy_tile: threadidx=%d, blockidx=%d, dy_tile=%f\n",threadIdx.x, blockIdx.x, dy);
+// printf("x: threadidx=%d, blockidx=%d, x_tile=%f\n",threadIdx.x, blockIdx.x, x);
+// printf("gamma: threadidx=%d, blockidx=%d, gamma_tile=%f\n",threadIdx.x, blockIdx.x, gamma);
 });
 }
 auto block_reduce2d_sync = ReducePolicy::template GetBlockReduce2dSync<Problem>();
+auto block_reduce2d_cross_warp_sync = ReducePolicy::template GetBlockReduce2dCrossWarpSync<Problem>();
 block_reduce2d_sync(ds_tile, ck_tile::ReduceOp::Add{});
 block_reduce2d_sync(db_tile, ck_tile::ReduceOp::Add{});
+// block_reduce2d_cross_warp_sync(ds_tile, smem, ck_tile::ReduceOp::Add{});
+// block_reduce2d_cross_warp_sync(db_tile, smem, ck_tile::ReduceOp::Add{});
 // sweep_tile(x_tile, [&](auto idx) {
 // constexpr auto i_idx = make_tuple(idx[number<0>{}]);
-// printf("post::threadidx=%d, blockidx=%d, ds_tile=%f\n",threadIdx.x, blockIdx.x,
-// ds_tile[i_idx]);
+// printf("db_tile post: threadidx=%d, blockidx=%d, db_tile=%f\n",threadIdx.x, blockIdx.x,
+// db_tile[i_idx]);
 // });
-//store_tile(ds_window, ds_tile);
-//store_tile(db_window, db_tile);
+// store_tile(ds_window, ds_tile);
+// store_tile(db_window, db_tile);
 ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
 move_tile_window(x_window, {0, -Block_N});
 move_tile_window(dy_window, {0, -Block_N});
 move_tile_window(gamma_window, {-Block_N});
 move_tile_window(dx_window, {0, stride_to_right_most_window});
-move_tile_window(dbeta_window, {0, stride_to_right_most_window});
-move_tile_window(dgamma_window, {0, stride_to_right_most_window});
+// move_tile_window(dbeta_window, {0, stride_to_right_most_window});
+// move_tile_window(dgamma_window, {0, stride_to_right_most_window});
 using XDistributedTensor = decltype(load_tile(x_window));
 constexpr auto spans = XDistributedTensor::get_distributed_spans();
 for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
 {
+const auto x_tile = load_tile(x_window);
+const auto dy_tile = load_tile(dy_window);
+const auto gamma_tile = load_tile(gamma_window);
 sweep_tile_span(spans[number<0>{}], [&](auto i_idx) {
 constexpr auto idx0 = make_tuple(i_idx);
 const auto mean = type_convert<ComputeDataType>(mean_tile[idx0]);
@@ -157,26 +169,28 @@ struct Layernorm2dBwdGammaBetaPipeline
 auto c = -b * mean - db_tile[idx0] * inv_std / row_size;
 sweep_tile_span(spans[number<1>{}], [&](auto j_idx) {
+constexpr auto idx1 = make_tuple(j_idx);
 constexpr auto idx = make_tuple(i_idx, j_idx);
-constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
+//constexpr auto gb_idx = make_tuple(number<0>{}, j_idx);
 const auto x = type_convert<ComputeDataType>(x_tile[idx]);
 const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
-const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx]);
-dbeta(gb_idx) += dy;
-dgamma(gb_idx) += dy * (x - mean) * inv_std;
+const auto gamma = type_convert<ComputeDataType>(gamma_tile[idx1]);
+// dbeta(gb_idx) += dy;
+// dgamma(gb_idx) += dy * (x - mean) * inv_std;
 dx(idx) = dy * gamma * inv_std + b * x + c;
+//printf("dx: threadidx=%d, blockidx=%d, dx_tile=%f\n",threadIdx.x, blockIdx.x, dx(idx));
 });
 });
-store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
-store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
+// store_tile(dbeta_window, cast_tile<BetaDataType>(dbeta));
+// store_tile(dgamma_window, cast_tile<GammaDataType>(dgamma));
 store_tile(dx_window, cast_tile<XDataType>(dx));
 move_tile_window(x_window, {0, -Block_N});
 move_tile_window(dy_window, {0, -Block_N});
 move_tile_window(gamma_window, {-Block_N});
 move_tile_window(dx_window, {0, -Block_N});
-move_tile_window(dbeta_window, {0, -Block_N});
-move_tile_window(dgamma_window, {0, -Block_N});
+// move_tile_window(dbeta_window, {0, -Block_N});
+// move_tile_window(dgamma_window, {0, -Block_N});
 }
 }
 };
...
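For context on the dx path above, a hedged sketch of the math it appears to implement; the definition of b lies outside the shown hunk, so the form below is the standard LayerNorm input gradient, written to be consistent with the ds/db accumulation and the c expression in this diff. With per-row mean \(\mu\), inverse standard deviation \(r\) (inv_std), row length \(N\) (row_size), and \(g_j = dy_j\,\gamma_j\):

\[
ds = \sum_j g_j x_j, \qquad db = \sum_j g_j,
\]
\[
dx_j = r\,g_j + b\,x_j + c, \quad \text{with presumably } b = -\frac{r^{3}}{N}\,(ds - \mu\,db), \qquad c = -b\,\mu - \frac{r\,db}{N},
\]

which is equivalent to the usual

\[
dx_j = r\Big(g_j - \tfrac{1}{N}\textstyle\sum_k g_k - r^{2}(x_j-\mu)\,\tfrac{1}{N}\textstyle\sum_k g_k (x_k-\mu)\Big).
\]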