Commit 96568141 authored by rocking

1. Add saving of mean and inverse std back

2. Move construction of tensor_view and tile_window to operator()
parent e0b473b6
@@ -55,6 +55,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
     ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
+    // TODO - move SAVE_MEAN_INV_STD to user args
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
+    ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
+#endif
 
     ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
     ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
     ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
@@ -63,6 +69,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
+#ifdef SAVE_MEAN_INV_STD
+    ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
+#endif
 
     x_buf.ToDevice(x_host.data());
     gamma_buf.ToDevice(gamma_host.data());
@@ -74,6 +84,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              gamma_buf.GetDeviceBuffer(),
                              beta_buf.GetDeviceBuffer(),
                              y_buf.GetDeviceBuffer(),
+#ifdef SAVE_MEAN_INV_STD
+                             mean_buf.GetDeviceBuffer(),
+                             invStd_buf.GetDeviceBuffer(),
+#else
+                             nullptr,
+                             nullptr,
+#endif
                              epsilon,
                              M,
                              N};
@@ -121,6 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
     std::cout << std::endl << std::flush;
+    std::cout << "pass = " << pass << std::endl;
 
     return pass;
 }
...
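The hunks above only allocate the extra host and device buffers for the two new outputs when the example is built with SAVE_MEAN_INV_STD defined; otherwise nullptr is forwarded for both pointers. A minimal stand-alone sketch of that pattern (plain C++; KernelArgs is a hypothetical stand-in for the example's argument struct, not ck_tile API):

// Illustrative only: a compile-time flag decides whether optional result
// buffers are allocated and passed; otherwise nullptr is forwarded, mirroring
// the SAVE_MEAN_INV_STD pattern in the example above.
#include <cstdio>
#include <vector>

struct KernelArgs // hypothetical stand-in for the example's argument struct
{
    const float* p_x;
    float* p_y;
    float* p_mean;   // may be nullptr when the mean is not requested
    float* p_invStd; // may be nullptr when the inverse std is not requested
};

int main()
{
    constexpr int M = 4, N = 8;
    std::vector<float> x(M * N, 1.0f), y(M * N);

#ifdef SAVE_MEAN_INV_STD
    std::vector<float> mean(M), inv_std(M);
    KernelArgs args{x.data(), y.data(), mean.data(), inv_std.data()};
#else
    KernelArgs args{x.data(), y.data(), nullptr, nullptr};
#endif

    std::printf("mean requested: %s\n", args.p_mean ? "yes" : "no");
    return 0;
}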
@@ -56,8 +56,8 @@ struct layernorm2d_fwd_args
     const void* p_gamma;
     const void* p_beta;
     void* p_y;
-    // void* p_mean;
-    // void* p_invStd;
+    void* p_mean;
+    void* p_invStd;
     float epsilon;
     ck_tile::index_t M;
     ck_tile::index_t N;
...
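With p_mean and p_invStd now real members, the host API can return the per-row statistics the kernel already computes internally: the row mean and 1/sqrt(var + epsilon), the latter matching the InvSqrt(var, epsilon) step in the kernel below. A host-side reference sketch (plain C++, illustrative only; it assumes the conventional biased 1/N variance):

// Illustrative reference for the two saved statistics: per row,
// mean = sum(x)/N and inv_std = 1/sqrt(var + epsilon). Not ck_tile code.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 4;
    const float epsilon = 1e-5f;
    std::vector<float> x = {1, 2, 3, 4,
                            2, 2, 2, 2};
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f, var = 0.f;
        for(int n = 0; n < N; ++n)
            mean += x[m * N + n];
        mean /= N;
        for(int n = 0; n < N; ++n)
        {
            const float d = x[m * N + n] - mean;
            var += d * d;
        }
        var /= N; // biased (1/N) variance, as is conventional for layernorm
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        std::printf("row %d: mean = %f, inv_std = %f\n", m, mean, inv_std);
    }
    return 0;
}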
@@ -59,6 +59,8 @@ struct layernorm_dispatch
                                  param.p_gamma,
                                  param.p_beta,
                                  param.p_y,
+                                 param.p_mean,
+                                 param.p_invStd,
                                  param.epsilon,
                                  param.M,
                                  param.N));
...
@@ -31,10 +31,15 @@ struct Layernorm2dFwd
    static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
    static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
+   static constexpr bool kPadM    = false; // TODO - Problem::kPadM
    static constexpr bool kPadN    = Problem::kPadN;
    static constexpr bool kTwoPass = Problem::kTwoPass;
 
    static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
+   static constexpr ck_tile::index_t kNPerThread     = Problem::BlockShape::kNPerThread;
+
+   static constexpr auto I0 = number<0>{};
+   static constexpr auto I1 = number<1>{};
 
    struct Kargs
    {
@@ -43,8 +48,8 @@ struct Layernorm2dFwd
        const void* p_beta;
        void* p_y;
-       // void* p_mean;
-       // void* p_invStd;
+       void* p_mean;
+       void* p_invStd;
        float epsilon;
@@ -150,53 +155,24 @@ struct Layernorm2dFwd
        return iN0 * S::kNPerThread + iN3;
    }
 
-   template <bool Cond = (kHasGamma && kHasBeta)>
-   CK_TILE_DEVICE std::enable_if_t<Cond> OnePassLayernorm2dFwd(const XDataType* p_x,
-                                                                const GammaDataType* p_gamma,
-                                                                const BetaDataType* p_beta,
-                                                                YDataType* p_y,
-                                                                const ComputeDataType epsilon,
-                                                                ck_tile::index_t M,
-                                                                ck_tile::index_t N) const
+   template <typename XBlockWindow,
+             typename GammaBlockWindow,
+             typename BetaBlockWindow,
+             typename YBlockWindow,
+             typename MeanBlockWindow,
+             typename InvStdBlockWindow,
+             bool Cond = (kHasGamma && kHasBeta)>
+   CK_TILE_DEVICE std::enable_if_t<Cond>
+   OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
+                         GammaBlockWindow& gamma_block_window,
+                         BetaBlockWindow& beta_block_window,
+                         YBlockWindow& y_block_window,
+                         MeanBlockWindow& mean_block_window,
+                         InvStdBlockWindow& inv_std_block_window,
+                         ComputeDataType epsilon,
+                         ck_tile::index_t N) const
    {
-       using S = typename Problem::BlockShape;
-       constexpr auto I0 = number<0>{};
-       constexpr auto I1 = number<1>{};
-
-       const auto x_m_n = [&]() {
-           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(x_dram_naive,
-                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-                                  sequence<false, kPadN>{});
-       }();
-
-       const auto gamma_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
-       }();
-
-       const auto beta_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
-       }();
-
-       const auto iM = get_block_id() * kMPerBlock;
-
-       constexpr auto xDstr = MakeXBlockTileDistribution();
-
-       auto x_block_window = make_tile_window(
-           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
-
        auto intra_thread_count_last = GetLastloopLayerNormIntraLaneReduceCount(N);
        ThreadWelford<ComputeDataType, XDataType> thread_welford{intra_thread_count_last};
 
        using XTensorType = decltype(load_tile(x_block_window));
@@ -210,37 +186,21 @@ struct Layernorm2dFwd
        const auto x_block_tensor = load_tile(x_block_window);
 
        thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
 
-       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-       constexpr auto betaDstr  = gammaDstr;
-
-       auto gamma_block_window =
-           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-       auto beta_block_window =
-           make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
-
-       const auto gamma_block_tensor = load_tile(gamma_block_window);
-       const auto beta_block_tensor  = load_tile(beta_block_window);
-
        // TODO: support cross warp Welford
        WarpMergeWelford<ComputeDataType, true>{}(
            mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
 
        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
 
-       // TODO: Extract normalize pipeline
-       const auto y_m_n = [&]() {
-           const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(y_dram_naive,
-                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-                                  sequence<false, kPadN>{});
-       }();
-
-       auto y_block_window = make_tile_window(
-           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+       if constexpr(kSaveMean)
+           store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+       if constexpr(kSaveInvStd)
+           store_tile(inv_std_block_window,
+                      cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
+
+       // TODO: Extract normalize pipeline
+       const auto gamma_block_tensor = load_tile(gamma_block_window);
+       const auto beta_block_tensor  = load_tile(beta_block_window);
 
        constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
@@ -269,51 +229,24 @@ struct Layernorm2dFwd
        store_tile(y_block_window, y_block_tensor);
    }
 
-   template <bool Cond = (kHasGamma && kHasBeta)>
-   CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
-                                                                const GammaDataType* p_gamma,
-                                                                const BetaDataType* p_beta,
-                                                                YDataType* p_y,
-                                                                const ComputeDataType epsilon,
-                                                                ck_tile::index_t M,
-                                                                ck_tile::index_t N) const
+   template <typename XBlockWindow,
+             typename GammaBlockWindow,
+             typename BetaBlockWindow,
+             typename YBlockWindow,
+             typename MeanBlockWindow,
+             typename InvStdBlockWindow,
+             bool Cond = (kHasGamma && kHasBeta)>
+   CK_TILE_DEVICE std::enable_if_t<Cond>
+   TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
+                         GammaBlockWindow& gamma_block_window,
+                         BetaBlockWindow& beta_block_window,
+                         YBlockWindow& y_block_window,
+                         MeanBlockWindow& mean_block_window,
+                         InvStdBlockWindow& inv_std_block_window,
+                         ComputeDataType epsilon,
+                         ck_tile::index_t N) const
    {
        using S = typename Problem::BlockShape;
-       constexpr auto I0 = number<0>{};
-       constexpr auto I1 = number<1>{};
-
-       const auto x_m_n = [&]() {
-           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_x, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(x_dram_naive,
-                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-                                  sequence<false, true>{});
-       }();
-
-       const auto gamma_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_gamma, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
-       }();
-
-       const auto beta_n = [&]() {
-           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_beta, make_tuple(N), make_tuple(1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(
-               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<true>{});
-       }();
-
-       const auto iM = get_block_id() * kMPerBlock;
-
-       constexpr auto xDstr = MakeXBlockTileDistribution();
-
-       auto x_block_window = make_tile_window(
-           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
 
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane((N + kNPerBlock - 1) / kNPerBlock);
@@ -352,27 +285,11 @@ struct Layernorm2dFwd
        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
 
-       // TODO: Extract normalize pipeline
-       const auto y_m_n = [&]() {
-           const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-               p_y, make_tuple(M, N), make_tuple(N, 1), number<S::kNPerThread>{}, number<1>{});
-
-           return pad_tensor_view(y_dram_naive,
-                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
-                                  sequence<false, true>{});
-       }();
-
-       auto y_block_window = make_tile_window(
-           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
-
-       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-       constexpr auto betaDstr  = gammaDstr;
-
-       auto gamma_block_window =
-           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-       auto beta_block_window =
-           make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
+       if constexpr(kSaveMean)
+           store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+       if constexpr(kSaveInvStd)
+           store_tile(inv_std_block_window,
+                      cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
 
        // reverse read x to reuse cache
        ck_tile::index_t stride_to_right_most_window =
@@ -426,29 +343,137 @@ struct Layernorm2dFwd
                               const void* p_gamma,
                               const void* p_beta,
                               void* p_y,
+                              void* p_mean,
+                              void* p_invStd,
                               const ComputeDataType epsilon,
                               ck_tile::index_t M,
                               ck_tile::index_t N) const
    {
+       const auto x_m_n = [&]() {
+           const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const XDataType*>(p_x),
+               make_tuple(M, N),
+               make_tuple(N, 1),
+               number<kNPerThread>{},
+               number<1>{});
+
+           return pad_tensor_view(x_dram_naive,
+                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+                                  sequence<kPadM, kPadN>{});
+       }();
+
+       const auto gamma_n = [&]() {
+           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const GammaDataType*>(p_gamma),
+               make_tuple(N),
+               make_tuple(1),
+               number<kNPerThread>{},
+               number<1>{});
+
+           return pad_tensor_view(
+               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+       }();
+
+       const auto beta_n = [&]() {
+           const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+               static_cast<const BetaDataType*>(p_beta),
+               make_tuple(N),
+               make_tuple(1),
+               number<kNPerThread>{},
+               number<1>{});
+
+           return pad_tensor_view(
+               gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+       }();
+
+       const auto iM = get_block_id() * kMPerBlock;
+
+       constexpr auto xDstr = MakeXBlockTileDistribution();
+
+       auto x_block_window = make_tile_window(
+           x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
+
+       const auto y_m_n = [&]() {
+           const auto y_dram_naive =
+               make_naive_tensor_view<address_space_enum::global>(static_cast<YDataType*>(p_y),
+                                                                  make_tuple(M, N),
+                                                                  make_tuple(N, 1),
+                                                                  number<kNPerThread>{},
+                                                                  number<1>{});
+
+           return pad_tensor_view(y_dram_naive,
+                                  make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+                                  sequence<kPadM, kPadN>{});
+       }();
+
+       auto y_block_window = make_tile_window(
+           y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+
+       constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
+       constexpr auto betaDstr  = gammaDstr;
+
+       auto gamma_block_window =
+           make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
+       auto beta_block_window =
+           make_tile_window(beta_n, make_tuple(number<kNPerBlock>{}), {0}, betaDstr);
+
+       auto mean_block_window = [&]() {
+           if constexpr(kSaveMean)
+           {
+               const auto mean_m = [&]() {
+                   const auto mean_dram_naive =
+                       make_naive_tensor_view_packed<address_space_enum::global>(
+                           static_cast<MeanDataType*>(p_mean), make_tuple(M), number<1>{});
+
+                   return pad_tensor_view(
+                       mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+               }();
+
+               return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
+           }
+           else
+               return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+       }();
+
+       auto inv_std_block_window = [&]() {
+           if constexpr(kSaveInvStd)
+           {
+               const auto inv_std_m = [&]() {
+                   const auto inv_std_dram_naive =
+                       make_naive_tensor_view_packed<address_space_enum::global>(
+                           static_cast<InvStdDataType*>(p_invStd), make_tuple(M), number<1>{});
+
+                   return pad_tensor_view(
+                       inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+               }();
+
+               return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
+           }
+           else
+               return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+       }();
+
        if constexpr(kTwoPass)
        {
-           TwoPassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
-                                 static_cast<const GammaDataType*>(p_gamma),
-                                 static_cast<const BetaDataType*>(p_beta),
-                                 static_cast<YDataType*>(p_y),
+           TwoPassLayernorm2dFwd(x_block_window,
+                                 gamma_block_window,
+                                 beta_block_window,
+                                 y_block_window,
+                                 mean_block_window,
+                                 inv_std_block_window,
                                  static_cast<const ComputeDataType>(epsilon),
-                                 M,
                                  N);
        }
        else
        {
-           OnePassLayernorm2dFwd(static_cast<const XDataType*>(p_x),
-                                 static_cast<const GammaDataType*>(p_gamma),
-                                 static_cast<const BetaDataType*>(p_beta),
-                                 static_cast<YDataType*>(p_y),
+           OnePassLayernorm2dFwd(x_block_window,
+                                 gamma_block_window,
+                                 beta_block_window,
+                                 y_block_window,
+                                 mean_block_window,
+                                 inv_std_block_window,
                                  static_cast<const ComputeDataType>(epsilon),
-                                 M,
                                  N);
        }
    }
...
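Two design points stand out in the kernel changes above: the tensor views and tile windows are now built once in operator() and handed to the one-pass/two-pass pipelines by reference, and the optional mean/inv-std outputs are chosen at compile time, with make_null_tile_window supplying a do-nothing placeholder when a statistic is not saved. The sketch below mimics that compile-time-optional-output idea in plain C++; RowSink, NullSink, and make_mean_sink are hypothetical names, not ck_tile API:

// Generic sketch: when the statistic is not requested, a no-op placeholder is
// passed instead of a real output window and the store is discarded at
// compile time.
#include <cstdio>
#include <vector>

struct RowSink // stand-in for a real tile window over global memory
{
    float* ptr;
    void store(int row, float value) const { ptr[row] = value; }
};

struct NullSink // stand-in for make_null_tile_window(): accepted, never written
{
};

template <bool kSaveMean>
auto make_mean_sink(float* p_mean)
{
    if constexpr(kSaveMean)
        return RowSink{p_mean};
    else
        return NullSink{};
}

template <bool kSaveMean, typename MeanSink>
void normalize_row(const MeanSink& mean_sink, int row, float mean)
{
    if constexpr(kSaveMean)
        mean_sink.store(row, mean); // discarded entirely when kSaveMean == false
    // ... normalization of the row would follow here ...
}

int main()
{
    std::vector<float> mean(2, 0.f);
    auto sink = make_mean_sink<true>(mean.data());
    normalize_row<true>(sink, 0, 1.5f);
    std::printf("saved mean[0] = %f\n", mean[0]);

    auto null_sink = make_mean_sink<false>(nullptr);
    normalize_row<false>(null_sink, 0, 1.5f); // no store happens
    return 0;
}

When the flag is false, the placeholder type flows through unchanged and the store statement never instantiates, so the disabled path adds no run-time cost; this mirrors how the kernel passes mean_block_window and inv_std_block_window unconditionally while guarding the store_tile calls with if constexpr.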