Commit c03a937d authored by dummycoderfe

one block ok

parent 7db609fe
@@ -4,25 +4,6 @@
 #include <ck_tile/core.hpp>
 #include "layernorm2d_bwd.hpp"
 
-template <typename DataType_,
-          ck_tile::index_t Repeat_M_,         // each thread repeats along M
-          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
-          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
-          bool kPadN_>
-using trait_ = layernorm2d_bwd_traits_<DataType_,
-                                       Repeat_M_,
-                                       ThreadPerBlock_M_,
-                                       ThreadPerBlock_N_,
-                                       kPadN_>;
-
-template <typename data_type>
-float layernorm2d_bwd_b16_(layernorm2d_bwd_traits /*t*/,
-                           layernorm2d_bwd_args a,
-                           const ck_tile::stream_config& s)
-{
-    return layernorm2d_bwd_<trait_<data_type, 1, 1, 64, true>>(s, a);
-}
-
 float layernorm2d_bwd(layernorm2d_bwd_traits t,
                       layernorm2d_bwd_args a,
                       const ck_tile::stream_config& s)
@@ -31,11 +12,11 @@ float layernorm2d_bwd(layernorm2d_bwd_traits t,
     float r = -1;
     if(t.data_type.compare("fp16") == 0)
     {
-        return layernorm2d_bwd_b16_<ck_tile::fp16_t>(t, a, s);
+        return layernorm2d_bwd_b16_<ck_tile::fp16_t>{}(t, a, s);
     }
     else if(t.data_type.compare("bf16") == 0)
     {
-        return layernorm2d_bwd_b16_<ck_tile::bf16_t>(t, a, s);
+        return layernorm2d_bwd_b16_<ck_tile::bf16_t>{}(t, a, s);
     }
     if(r < 0)
         throw std::runtime_error("Without supported instances!");
...
@@ -11,17 +11,6 @@
 using S = ck_tile::stream_config;
 using A = layernorm2d_bwd_args;
 
-template <typename DataType_,
-          ck_tile::index_t Repeat_M_,         // each thread repeats along M
-          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
-          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
-          bool kPadN_>
-using trait_ = layernorm2d_bwd_traits_<DataType_,
-                                       Repeat_M_,
-                                       ThreadPerBlock_M_,
-                                       ThreadPerBlock_N_,
-                                       kPadN_>;
-
 template <typename Traits_>
 float layernorm2d_bwd_(const S& s, A a)
 {
...
@@ -64,21 +64,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
     using ComputeDataType = typename TypeConfig::ComputeDataType;
 
     // host verify
+    ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
     ck_tile::HostTensor<YDataType> dy_host({m, n}, {stride, 1});
     ck_tile::HostTensor<MeanDataType> mean_host({m});
     ck_tile::HostTensor<InvStdDataType> invStd_host({m});
-    ck_tile::HostTensor<GammaDataType> dgamma_host_dev({n});
-    ck_tile::HostTensor<BetaDataType> dbeta_host_dev({n});
-    ck_tile::HostTensor<GammaDataType> dgamma_host_ref({n});
-    ck_tile::HostTensor<BetaDataType> dbeta_host_ref({n});
+    ck_tile::index_t blockM   = layernorm2d_bwd_block_m<XDataType>();
+    ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;
+    ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
+    ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
+    ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
+    ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
 
-    // ck_tile::FillMonotonicSeq<YDataType>{}(dy_host);
     ck_tile::FillUniformDistribution<YDataType>{-.5f, .5f}(dy_host);
-    // ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
-    ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
+    ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
+    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
+    // ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
     ck_tile::FillUniformDistribution<InvStdDataType>{-.5f, .5f}(invStd_host);
 
+    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem dy_buf(dy_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem mean_buf(mean_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem invStd_buf(invStd_host.get_element_space_size_in_bytes());
@@ -86,6 +92,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem dgamma_buf(dgamma_host_dev.get_element_space_size_in_bytes());
     ck_tile::DeviceMem dbeta_buf(dbeta_host_dev.get_element_space_size_in_bytes());
 
+    x_buf.ToDevice(x_host.data());
     dy_buf.ToDevice(dy_host.data());
     mean_buf.ToDevice(mean_host.data());
     invStd_buf.ToDevice(invStd_host.data());
@@ -94,13 +101,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
 
     layernorm2d_bwd_traits traits{data_type};
 
-    layernorm2d_bwd_args args{dy_buf.GetDeviceBuffer(),
+    layernorm2d_bwd_args args{x_buf.GetDeviceBuffer(),
+                              dy_buf.GetDeviceBuffer(),
                               mean_buf.GetDeviceBuffer(),
                               invStd_buf.GetDeviceBuffer(),
                               dgamma_buf.GetDeviceBuffer(),
                               dbeta_buf.GetDeviceBuffer(),
-                              nullptr,
                               m,
                               n,
                               stride};
@@ -118,41 +124,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(do_validation)
     {
-        // // reference
-        // ck_tile::reference_layernorm2d_bwd<XDataType,
-        //                                    GammaDataType,
-        //                                    BetaDataType,
-        //                                    ComputeDataType,
-        //                                    YDataType,
-        //                                    MeanDataType,
-        //                                    InvStdDataType>(
-        //     x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
-        // y_buf.FromDevice(y_host_dev.data());
-        // auto [rtol, atol] = get_elimit<DataType>();
-        // if(stride == n)
-        // {
-        //     pass = ck_tile::check_err(
-        //         y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
-        // }
-        // else
-        // {
-        //     for(int i_r = 0; i_r < m; i_r++)
-        //     {
-        //         std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
-        //                                               y_host_dev.begin() + i_r * stride + n);
-        //         std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
-        //                                               y_host_ref.begin() + i_r * stride + n);
-        //         pass &= ck_tile::check_err(y_host_dev_row,
-        //                                    y_host_ref_row,
-        //                                    std::string("OUT[") + std::to_string(i_r) +
-        //                                        std::string("] Error: Incorrect results!"),
-        //                                    rtol,
-        //                                    atol);
-        //     }
-        // }
+        // reference
+        ck_tile::reference_layernorm2d_bwd_gamma_part<XDataType,
+                                                      GammaDataType,
+                                                      BetaDataType,
+                                                      ComputeDataType,
+                                                      YDataType,
+                                                      MeanDataType,
+                                                      InvStdDataType>(
+            x_host, dy_host, mean_host, invStd_host, dgamma_host_ref, dbeta_host_ref);
+
+        dgamma_buf.FromDevice(dgamma_host_dev.data());
+        dbeta_buf.FromDevice(dbeta_host_dev.data());
+
+        auto [rtol, atol] = get_elimit<DataType>();
+        pass = ck_tile::check_err(
+            dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
+        pass &= ck_tile::check_err(
+            dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
 
         std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
     }
...
@@ -101,6 +101,17 @@ struct layernorm2d_bwd_traits_
     static constexpr bool kPadN = kPadN_;
 };
 
+template <typename DataType_,
+          ck_tile::index_t Repeat_M_,         // each thread repeats along M
+          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
+          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
+          bool kPadN_>
+using trait_ = layernorm2d_bwd_traits_<DataType_,
+                                       Repeat_M_,
+                                       ThreadPerBlock_M_,
+                                       ThreadPerBlock_N_,
+                                       kPadN_>;
+
 template <typename Traits_>
 float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
@@ -108,6 +119,24 @@ float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
 struct layernorm2d_bwd_traits
 {
     std::string data_type;
 };
+
+template <typename data_type>
+struct layernorm2d_bwd_b16_
+{
+    using Trait = trait_<data_type, 1, 1, 64, true>;
+
+    float operator()(layernorm2d_bwd_traits /*t*/,
+                     layernorm2d_bwd_args a,
+                     const ck_tile::stream_config& s)
+    {
+        return layernorm2d_bwd_<Trait>(s, a);
+    }
+};
+
+// expose the block size along M chosen by the dispatched trait, so the host
+// can size the per-block partial dgamma/dbeta buffers
+template <typename data_type>
+ck_tile::index_t layernorm2d_bwd_block_m()
+{
+    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
+}
 
 float layernorm2d_bwd(layernorm2d_bwd_traits, layernorm2d_bwd_args, const ck_tile::stream_config&);
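Note on the dispatch change above: turning `layernorm2d_bwd_b16_` into a functor keeps the selected tile configuration visible as `layernorm2d_bwd_b16_<T>::Trait`, which is what lets `layernorm2d_bwd_block_m<T>()` report `Block_M` to the host; the example uses it to size the `{reduce_m, n}` partial outputs. A minimal sketch of that query, assuming fp16 (`partial_rows` is a hypothetical helper, not part of the commit):

```cpp
// Hypothetical helper (illustration only): how many rows of per-block
// partial dgamma/dbeta a launch over m input rows will produce.
ck_tile::index_t partial_rows(ck_tile::index_t m)
{
    const ck_tile::index_t block_m = layernorm2d_bwd_block_m<ck_tile::fp16_t>();
    return (m + block_m - 1) / block_m; // ceil-div, the example's reduce_m
}
```

For example, if `Block_M` were 64, `m = 1000` would give `(1000 + 63) / 64 = 16` rows of partials.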
...
@@ -22,6 +22,7 @@
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
...
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+#include <thread>
+
+namespace ck_tile {
+
+template <typename XDataType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename ComputeDataType,
+          typename YDataType,
+          typename MeanDataType,
+          typename InvStdDataType>
+CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataType>& x_m_n,
+                                                       const HostTensor<YDataType>& dy_m_n,
+                                                       const HostTensor<MeanDataType>& mean_m,
+                                                       const HostTensor<InvStdDataType>& inv_std_m,
+                                                       HostTensor<GammaDataType>& dgamma_mpart_n,
+                                                       HostTensor<BetaDataType>& dbeta_mpart_n)
+{
+    const auto MN = x_m_n.mDesc.get_lengths();
+    const auto M  = MN[0];
+    const auto N  = MN[1];
+    // dgamma/dbeta are produced as PartM rows of partial sums; each row
+    // accumulates a contiguous chunk of MLoop input rows
+    const auto PartM = dgamma_mpart_n.mDesc.get_lengths()[0];
+    const auto MLoop = (M + PartM - 1) / PartM;
+
+    auto f = [&](auto m) {
+        const auto m_offset = m * MLoop;
+        for(int n = 0; n < N; ++n)
+        {
+            ComputeDataType gamma_acc = 0;
+            ComputeDataType beta_acc  = 0;
+            for(int inner_m = 0; inner_m < MLoop && m_offset + inner_m < M; inner_m++)
+            {
+                const ComputeDataType mean =
+                    ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
+                const ComputeDataType inv_std =
+                    ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
+                const ComputeDataType x =
+                    ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
+                const ComputeDataType dy =
+                    ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
+                gamma_acc += dy * (x - mean) * inv_std;
+                beta_acc += dy;
+            }
+            dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
+            dbeta_mpart_n(m, n)  = ck_tile::type_convert<BetaDataType>(beta_acc);
+        }
+    };
+
+    make_ParallelTensorFunctor(f, PartM)(std::thread::hardware_concurrency());
+}
+} // namespace ck_tile
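A minimal sketch of calling this reference on its own (hypothetical driver: fp16 storage with fp32 compute and the `part_m` value are assumptions; in the example, `part_m` is derived from `layernorm2d_bwd_block_m`):

```cpp
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"

void run_reference(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t part_m)
{
    ck_tile::HostTensor<ck_tile::fp16_t> x({m, n}, {n, 1});
    ck_tile::HostTensor<ck_tile::fp16_t> dy({m, n}, {n, 1});
    ck_tile::HostTensor<float> mean({m});
    ck_tile::HostTensor<float> inv_std({m});
    // each of the part_m rows accumulates a chunk of ceil(m / part_m) input rows
    ck_tile::HostTensor<ck_tile::fp16_t> dgamma_part({part_m, n});
    ck_tile::HostTensor<ck_tile::fp16_t> dbeta_part({part_m, n});

    ck_tile::reference_layernorm2d_bwd_gamma_part<ck_tile::fp16_t, // XDataType
                                                  ck_tile::fp16_t, // GammaDataType
                                                  ck_tile::fp16_t, // BetaDataType
                                                  float,           // ComputeDataType
                                                  ck_tile::fp16_t, // YDataType
                                                  float,           // MeanDataType
                                                  float>(          // InvStdDataType
        x, dy, mean, inv_std, dgamma_part, dbeta_part);
}
```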
...
@@ -11,13 +11,13 @@ namespace ck_tile {
 // host side args
 struct Layernorm2dBwdGammaBetaHostArgs
 {
+    const void* p_x;
     const void* p_dY;
     const void* p_mean;
     const void* p_invStd;
     void* p_dGamma;
     void* p_dBeta;
-    void* p_yMul;
 
     index_t m;
     index_t n;
@@ -51,13 +51,13 @@ struct Layernorm2dBwdGammaBeta
     struct Kargs
     {
+        const void* p_x;
         const void* p_dY;
         const void* p_mean;
         const void* p_invStd;
         void* p_dGamma;
         void* p_dBeta;
-        void* p_yMul;
 
         index_t m;
         index_t n;
@@ -67,12 +67,12 @@ struct Layernorm2dBwdGammaBeta
     CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
     {
-        return Kargs{hargs.p_dY,
+        return Kargs{hargs.p_x,
+                     hargs.p_dY,
                      hargs.p_mean,
                      hargs.p_invStd,
                      hargs.p_dGamma,
                      hargs.p_dBeta,
-                     hargs.p_yMul,
                      hargs.m,
                      hargs.n,
                      hargs.stride};
@@ -119,7 +119,22 @@ struct Layernorm2dBwdGammaBeta
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
-        const auto iM = get_block_id() * Block_M;
+        const auto block_id = get_block_id();
+        const auto iM       = block_id * Block_M;
+
+        const auto x_window = [&]() {
+            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<const XDataType*>(kargs.p_x),
+                make_tuple(kargs.m, kargs.n),
+                make_tuple(kargs.stride, 1));
+            // NOTE: no padding is applied for this load; the kernel is assumed
+            // to check the valid element count dynamically
+            const auto tmp2_ = pad_tensor_view(
+                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
+            return make_tile_window(
+                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
+        }();
 
         const auto dy_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
@@ -144,7 +159,7 @@ struct Layernorm2dBwdGammaBeta
             const auto tmp2_ =
                 pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
-            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
+            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
         }();
 
         const auto invstd_window = [&]() {
@@ -156,36 +171,37 @@ struct Layernorm2dBwdGammaBeta
             const auto tmp2_ =
                 pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
-            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
+            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
         }();
 
-        const auto dgamma_window = [&]() {
+        auto dgamma_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
-                static_cast<const GammaDataType*>(kargs.p_dGamma),
-                make_tuple(kargs.n),
-                make_tuple(1));
+                static_cast<GammaDataType*>(kargs.p_dGamma),
+                make_tuple(gridDim.x, kargs.n),
+                make_tuple(kargs.n, 1));
             const auto tmp2_ =
-                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
-            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
+                pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
+            return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
         }();
 
-        const auto dbeta_window = [&]() {
+        auto dbeta_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
-                static_cast<const BetaDataType*>(kargs.p_dBeta),
-                make_tuple(kargs.n),
-                make_tuple(1));
+                static_cast<BetaDataType*>(kargs.p_dBeta),
+                make_tuple(gridDim.x, kargs.n),
+                make_tuple(kargs.n, 1));
             const auto tmp2_ =
-                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
-            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
+                pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
+            return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
        }();
 
         __shared__ char smem[GetSmemSize()];
 
-        Pipeline{}(dy_window,
+        Pipeline{}(x_window,
+                   dy_window,
                    mean_window,
                    invstd_window,
                    dgamma_window,
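With this change each block writes one row of partial `dgamma`/`dbeta` into a `gridDim.x x n` view instead of accumulating into a single length-`n` vector, matching the "one block ok" scope of the commit. A second reduction over those rows is still needed to obtain the final gradients; a hedged host-side sketch of that missing step (`reduce_partials` is hypothetical, not part of the commit):

```cpp
#include <vector>
#include <ck_tile/core.hpp>

// Hypothetical second-stage reduction: collapse the reduce_m rows of
// per-block partials (row-major {reduce_m, n}) into the final length-n result.
template <typename T, typename Compute = float>
std::vector<T> reduce_partials(const std::vector<T>& part, int reduce_m, int n)
{
    std::vector<T> out(n);
    for(int j = 0; j < n; ++j)
    {
        Compute acc = 0;
        for(int i = 0; i < reduce_m; ++i)
            acc += ck_tile::type_convert<Compute>(part[i * n + j]);
        out[j] = ck_tile::type_convert<T>(acc);
    }
    return out;
}
```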
...
@@ -10,7 +10,7 @@ namespace ck_tile {
 struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
 {
     template <typename Problem>
-    CK_TILE_DEVICE static constexpr auto MakeDyBlockTileDistribution()
+    CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
     {
         using S = typename Problem::BlockShape;
@@ -18,11 +18,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
             tile_distribution_encoding<
                 sequence<>,
                 tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
-                      sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
+                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
                 tuple<sequence<1, 2>, sequence<1, 2>>,
-                tuple<sequence<1, 0>, sequence<2, 1>>,
-                sequence<1>,
-                sequence<0>>{});
+                tuple<sequence<1, 1>, sequence<2, 2>>,
+                sequence<1, 2>,
+                sequence<0, 0>>{});
     }
 
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
@@ -39,20 +39,22 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
                 sequence<0>>{});
     }
 
-    // template <typename Problem>
-    // CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
-    // {
-    //     using S = typename Problem::BlockShape;
-    //     return make_static_tile_distribution(
-    //         tile_distribution_encoding<
-    //             sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
-    //             tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
-    //             tuple<sequence<0, 1>, sequence<0, 1>>,
-    //             tuple<sequence<1, 0>, sequence<2, 1>>,
-    //             sequence<0>,
-    //             sequence<0>>{});
-    // }
+    template <typename Problem>
+    CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
+    {
+        using S = typename Problem::BlockShape;
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<>,
+                tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
+                tuple<sequence<1, 2>, sequence<1, 2>>,
+                tuple<sequence<0, 1>, sequence<1, 2>>,
+                sequence<2>,
+                sequence<0>>{});
    }
 
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
...
@@ -24,7 +24,7 @@ struct Layernorm2dBwdGammaBetaPipeline
     using MeanDataType   = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
     using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
 
-    static constexpr bool kPadM = false; // TODO - BlockLayernorm2dBwdGammaBetaProblem::kPadM
+    static constexpr bool kPadM = false;
     static constexpr bool kPadN = Problem::kPadN;
 
     static constexpr const char* name = []() {
@@ -35,31 +35,13 @@ struct Layernorm2dBwdGammaBetaPipeline
     {
         return Policy::template GetSmemSize<Problem>();
     }
 
-    // template <typename DumpTensor_>
-    // CK_TILE_DEVICE void dump(const DumpTensor_& x) const
-    // {
-    //     constexpr auto I0    = number<0>{};
-    //     constexpr auto I1    = number<1>{};
-    //     constexpr auto spans = DumpTensor_::get_distributed_spans();
-    //     sweep_tile_span(spans[I1], [&](auto i1) {
-    //         sweep_tile_span(spans[I0], [&](auto i0) {
-    //             constexpr auto in_dstr_idx = make_tuple(i0, i1);
-    //             auto v = ck_tile::type_convert<float>(x[in_dstr_idx]);
-    //             index_t tid = (threadIdx.z * (blockDim.x * blockDim.y)) +
-    //                           (threadIdx.y * blockDim.x) + threadIdx.x;
-    //             printf("%d %f\n", tid, v);
-    //         });
-    //     });
-    // }
-
-    template <typename DYWindow,
+    template <typename XWindow,
               typename MeanWindow,
               typename InvStdWindow,
               typename DGammaWindow,
               typename DBetaWindow>
-    CK_TILE_DEVICE auto operator()(const DYWindow& dy_window_,
+    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
+                                   const XWindow& dy_window_,
                                    const MeanWindow& mean_window_,
                                    const InvStdWindow& inv_std_window_,
                                    DGammaWindow& gamma_window_,
@@ -67,52 +49,43 @@ struct Layernorm2dBwdGammaBetaPipeline
                                    ck_tile::index_t row_size,
                                    void* smem) const
     {
-        const auto dy_window = make_tile_window(
-            dy_window_, Policy::template MakeDyBlockTileDistribution<Problem>());
-        const auto mean_window = make_tile_window(
-            mean_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
-        const auto inv_std_window = make_tile_window(
-            inv_std_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
-        // const auto gamma_window = make_tile_window(
-        //     gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
-        // const auto beta_window = make_tile_window(
-        //     beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
-
-        const auto dy      = load_tile(dy_window);
-        const auto mean    = load_tile(mean_window);
-        const auto inv_std = load_tile(inv_std_window);
-
-        // auto y = make_static_distributed_tensor<YDataType>(dy.get_tile_distribution());
-        sweep_tile(mean, [&](auto idx) {
-            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
-            // constexpr auto j_idx = make_tuple(idx[number<1>{}]);
-            index_t tid  = (threadIdx.y * blockDim.x) + threadIdx.x;
-            const auto m = type_convert<float>(mean[i_idx]);
-            if(blockIdx.x == 0 && blockIdx.y == 0)
-                printf("%d %f\n", tid, m);
+        auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
+        auto mean_dist       = Policy::template MakeMeanBlockTileDistribution<Problem>();
+        auto x_dist          = Policy::template MakeXBlockTileDistribution<Problem>();
+
+        const auto x_window       = make_tile_window(x_window_, x_dist);
+        const auto dy_window      = make_tile_window(dy_window_, x_dist);
+        const auto mean_window    = make_tile_window(mean_window_, mean_dist);
+        const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
+
+        const auto x_tile       = load_tile(x_window);
+        const auto dy_tile      = load_tile(dy_window);
+        const auto mean_tile    = load_tile(mean_window);
+        const auto inv_std_tile = load_tile(inv_std_window);
+
+        auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist);
+        auto beta_window  = make_tile_window(beta_window_, gamma_beta_dist);
+
+        auto gamma_tile = make_static_distributed_tensor<GammaDataType>(gamma_beta_dist);
+        auto beta_tile  = make_static_distributed_tensor<BetaDataType>(gamma_beta_dist);
+        // zero the accumulators before the += below
+        // (make_static_distributed_tensor does not initialize its storage)
+        clear_tile(gamma_tile);
+        clear_tile(beta_tile);
+
+        sweep_tile(x_tile, [&](auto idx) {
+            constexpr auto i_idx  = make_tuple(idx[number<0>{}]);
+            constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]);
+
+            auto& gamma = gamma_tile(gb_idx);
+            auto& beta  = beta_tile(gb_idx);
+
+            const auto x       = type_convert<ComputeDataType>(x_tile[idx]);
+            const auto dy      = type_convert<ComputeDataType>(dy_tile[idx]);
+            const auto mean    = type_convert<ComputeDataType>(mean_tile[i_idx]);
+            const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
+
+            beta += type_convert<BetaDataType>(dy);
+            gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
         });
-        // dump(dy);
-        // dump(mean);
-        // dump(inv_std);
-        *reinterpret_cast<char*>(smem) = row_size;
-
-        // layernorm computation
-        // auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
-        // sweep_tile(y, [&, mean_ = mean](auto idx) {
-        //     constexpr auto i_idx  = make_tuple(idx[number<0>{}]);
-        //     constexpr auto j_idx  = make_tuple(idx[number<1>{}]);
-        //     const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
-        //     const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);
-        //     const auto x_     = type_convert<ComputeDataType>(x[idx]);
-        //     auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
-        //     y(idx)  = type_convert<YDataType>(y_);
-        // });
-        // store_tile(y_window, y);
+
+        store_tile(gamma_window, gamma_tile);
+        store_tile(beta_window, beta_tile);
     }
 };
 } // namespace ck_tile