Commit c03a937d authored by dummycoderfe's avatar dummycoderfe
Browse files

one block ok

parent 7db609fe
......@@ -4,25 +4,6 @@
#include <ck_tile/core.hpp>
#include "layernorm2d_bwd.hpp"
// Shorthand alias binding the compile-time kernel configuration for the
// layernorm2d backward instance: element type, per-thread repeat along M,
// and the 2D thread-block shape, plus whether N needs padding.
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
kPadN_>;
// Dispatch helper for 16-bit float element types (fp16/bf16): instantiates the
// backward kernel with the one supported tile configuration and launches it on
// the given stream. The runtime traits argument is unused here.
template <typename data_type>
float layernorm2d_bwd_b16_(layernorm2d_bwd_traits /*t*/,
                           layernorm2d_bwd_args a,
                           const ck_tile::stream_config& s)
{
    // Fixed tiling: Repeat_M = 1, ThreadPerBlock_M = 1, ThreadPerBlock_N = 64, pad N.
    using kernel_traits = trait_<data_type, 1, 1, 64, true>;
    return layernorm2d_bwd_<kernel_traits>(s, a);
}
float layernorm2d_bwd(layernorm2d_bwd_traits t,
layernorm2d_bwd_args a,
const ck_tile::stream_config& s)
......@@ -31,11 +12,11 @@ float layernorm2d_bwd(layernorm2d_bwd_traits t,
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
return layernorm2d_bwd_b16_<ck_tile::fp16_t>(t, a, s);
return layernorm2d_bwd_b16_<ck_tile::fp16_t>{}(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return layernorm2d_bwd_b16_<ck_tile::bf16_t>(t, a, s);
return layernorm2d_bwd_b16_<ck_tile::bf16_t>{}(t, a, s);
}
if(r < 0)
throw std::runtime_error("Without supported instances!");
......
......@@ -11,17 +11,6 @@
using S = ck_tile::stream_config;
using A = layernorm2d_bwd_args;
// Shorthand alias binding the compile-time kernel configuration for the
// layernorm2d backward instance: element type, per-thread repeat along M,
// and the 2D thread-block shape, plus whether N needs padding.
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
kPadN_>;
template <typename Traits_>
float layernorm2d_bwd_(const S& s, A a)
{
......
......@@ -64,21 +64,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> dy_host({m, n}, {stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host({m});
ck_tile::HostTensor<InvStdDataType> invStd_host({m});
ck_tile::HostTensor<GammaDataType> dgamma_host_dev({n});
ck_tile::HostTensor<BetaDataType> dbeta_host_dev({n});
ck_tile::HostTensor<GammaDataType> dgamma_host_ref({n});
ck_tile::HostTensor<BetaDataType> dbeta_host_ref({n});
ck_tile::index_t blockM = layernorm2d_bwd_block_m<XDataType>();
ck_tile::index_t reduce_m = (m + blockM - 1) / blockM;
ck_tile::HostTensor<GammaDataType> dgamma_host_dev({reduce_m, n});
ck_tile::HostTensor<BetaDataType> dbeta_host_dev({reduce_m, n});
ck_tile::HostTensor<GammaDataType> dgamma_host_ref({reduce_m, n});
ck_tile::HostTensor<BetaDataType> dbeta_host_ref({reduce_m, n});
// ck_tile::FillMonotonicSeq<YDataType>{}(dy_host);
ck_tile::FillUniformDistribution<YDataType>{-.5f, .5f}(dy_host);
// ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
ck_tile::FillUniformDistribution<MeanDataType>{-.5f, .5f}(mean_host);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
// ck_tile::FillMonotonicSeq<MeanDataType>{}(mean_host);
ck_tile::FillUniformDistribution<InvStdDataType>{-.5f, .5f}(invStd_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dy_buf(dy_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem mean_buf(mean_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host.get_element_space_size_in_bytes());
......@@ -86,6 +92,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem dgamma_buf(dgamma_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbeta_buf(dbeta_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
dy_buf.ToDevice(dy_host.data());
mean_buf.ToDevice(mean_host.data());
invStd_buf.ToDevice(invStd_host.data());
......@@ -94,13 +101,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_bwd_traits traits{data_type};
layernorm2d_bwd_args args{dy_buf.GetDeviceBuffer(),
layernorm2d_bwd_args args{x_buf.GetDeviceBuffer(),
dy_buf.GetDeviceBuffer(),
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
dgamma_buf.GetDeviceBuffer(),
dbeta_buf.GetDeviceBuffer(),
nullptr,
m,
n,
stride};
......@@ -118,41 +124,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
// // reference
// ck_tile::reference_layernorm2d_bwd<XDataType,
// GammaDataType,
// BetaDataType,
// ComputeDataType,
// YDataType,
// MeanDataType,
// InvStdDataType>(
// x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
// y_buf.FromDevice(y_host_dev.data());
// auto [rtol, atol] = get_elimit<DataType>();
// if(stride == n)
// {
// pass = ck_tile::check_err(
// y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
// }
// else
// {
// for(int i_r = 0; i_r < m; i_r++)
// {
// std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
// y_host_dev.begin() + i_r * stride + n);
// std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
// y_host_ref.begin() + i_r * stride + n);
// pass &= ck_tile::check_err(y_host_dev_row,
// y_host_ref_row,
// std::string("OUT[") + std::to_string(i_r) +
// std::string("] Error: Incorrect results!"),
// rtol,
// atol);
// }
// }
// reference
ck_tile::reference_layernorm2d_bwd_gamma_part<XDataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
x_host, dy_host, mean_host, invStd_host, dgamma_host_ref, dbeta_host_ref);
dgamma_buf.FromDevice(dgamma_host_dev.data());
dbeta_buf.FromDevice(dbeta_host_dev.data());
auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err(
dgamma_host_dev, dgamma_host_ref, std::string("GAMMA OUT Error: Incorrect results!"), rtol, atol);
pass &= ck_tile::check_err(
dbeta_host_dev, dbeta_host_ref, std::string("BETA OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
......
......@@ -101,6 +101,17 @@ struct layernorm2d_bwd_traits_
static constexpr bool kPadN = kPadN_;
};
// Shorthand alias binding the compile-time kernel configuration for the
// layernorm2d backward instance: element type, per-thread repeat along M,
// and the 2D thread-block shape, plus whether N needs padding.
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
bool kPadN_>
using trait_ = layernorm2d_bwd_traits_<DataType_,
Repeat_M_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
kPadN_>;
// Launches the layernorm2d backward kernel instantiated for Traits_;
// the definition lives in the per-instance translation unit.
template <typename Traits_>
float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
......@@ -108,6 +119,24 @@ float layernorm2d_bwd_(const ck_tile::stream_config& s, layernorm2d_bwd_args a);
// Runtime-selected traits for the host-side layernorm2d backward dispatch.
struct layernorm2d_bwd_traits
{
std::string data_type; // element type name, e.g. "fp16" or "bf16"
};
// Functor dispatching the layernorm2d backward kernel for 16-bit float element
// types (fp16/bf16). Exposes the fixed tile configuration as `Trait` so the
// host can query compile-time parameters (e.g. Block_M) of the launched kernel.
template <typename data_type>
struct layernorm2d_bwd_b16_
{
    // Fixed tiling: Repeat_M = 1, ThreadPerBlock_M = 1, ThreadPerBlock_N = 64, pad N.
    using Trait = trait_<data_type, 1, 1, 64, true>;

    // Launch the kernel on stream `s`. The runtime traits argument is unused;
    // the configuration is fully determined by Trait. Marked const: the
    // functor is stateless.
    float operator()(layernorm2d_bwd_traits /*t*/,
                     layernorm2d_bwd_args a,
                     const ck_tile::stream_config& s) const
    {
        return layernorm2d_bwd_<Trait>(s, a);
    }
};
// Number of rows (M) processed per thread block by the b16 backward kernel.
// The host uses this to size the per-block partial dgamma/dbeta buffers
// (reduce_m = ceil(m / Block_M) rows of partial results).
template <typename data_type>
ck_tile::index_t layernorm2d_bwd_block_m()
{
    return layernorm2d_bwd_b16_<data_type>::Trait::Block_M;
} // note: no trailing ';' — a function definition is not a declaration statement
float layernorm2d_bwd(layernorm2d_bwd_traits, layernorm2d_bwd_args, const ck_tile::stream_config&);
......@@ -22,6 +22,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_bwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
......
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
// CPU reference for the partial (per-block) dgamma/dbeta reduction of
// layernorm2d backward.
//
// For each partition p of the M rows and each column n it accumulates
//   dgamma[p][n] = sum_m dy[m][n] * (x[m][n] - mean[m]) * inv_std[m]
//   dbeta[p][n]  = sum_m dy[m][n]
// over the rows owned by partition p, mirroring the device kernel that writes
// one partial row per thread block.
//
// x_m_n / dy_m_n        : (M, N) input and upstream-gradient tensors
// mean_m / inv_std_m    : (M) per-row statistics from the forward pass
// dgamma_mpart_n/dbeta_ : (PartM, N) outputs, one row of partials per partition
// part_m_size           : rows per partition. Pass the device kernel's Block_M
//                         here. When 0 (default) it falls back to
//                         ceil(M / PartM), which equals Block_M only when M is
//                         a multiple of Block_M — for ragged M the default
//                         partitioning would not line up with the device's
//                         [block * Block_M, block * Block_M + Block_M) rows.
template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType>
CK_TILE_HOST void reference_layernorm2d_bwd_gamma_part(const HostTensor<XDataType>& x_m_n,
                                                       const HostTensor<YDataType>& dy_m_n,
                                                       const HostTensor<MeanDataType>& mean_m,
                                                       const HostTensor<InvStdDataType>& inv_std_m,
                                                       HostTensor<GammaDataType>& dgamma_mpart_n,
                                                       HostTensor<BetaDataType>& dbeta_mpart_n,
                                                       index_t part_m_size = 0)
{
    const auto MN    = x_m_n.mDesc.get_lengths();
    const auto M     = MN[0];
    const auto N     = MN[1];
    const auto PartM = dgamma_mpart_n.mDesc.get_lengths()[0];

    // Rows owned by each partition; see part_m_size documentation above.
    const index_t MLoop =
        part_m_size > 0 ? part_m_size : static_cast<index_t>((M + PartM - 1) / PartM);

    // One parallel task per partition row of the output.
    auto f = [&](auto m) {
        const auto m_offset = m * MLoop;
        for(index_t n = 0; n < static_cast<index_t>(N); ++n)
        {
            // Accumulate in ComputeDataType to limit rounding error before the
            // final narrowing conversion.
            ComputeDataType gamma_acc = 0;
            ComputeDataType beta_acc  = 0;
            for(index_t inner_m = 0;
                inner_m < MLoop && m_offset + inner_m < static_cast<index_t>(M);
                inner_m++)
            {
                const ComputeDataType mean =
                    ck_tile::type_convert<ComputeDataType>(mean_m(m_offset + inner_m));
                const ComputeDataType inv_std =
                    ck_tile::type_convert<ComputeDataType>(inv_std_m(m_offset + inner_m));
                const ComputeDataType x =
                    ck_tile::type_convert<ComputeDataType>(x_m_n(m_offset + inner_m, n));
                const ComputeDataType dy =
                    ck_tile::type_convert<ComputeDataType>(dy_m_n(m_offset + inner_m, n));
                gamma_acc += dy * (x - mean) * inv_std;
                beta_acc += dy;
            }
            dgamma_mpart_n(m, n) = ck_tile::type_convert<GammaDataType>(gamma_acc);
            dbeta_mpart_n(m, n)  = ck_tile::type_convert<BetaDataType>(beta_acc);
        }
    };

    make_ParallelTensorFunctor(f, PartM)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
......@@ -11,13 +11,13 @@ namespace ck_tile {
// host side args
struct Layernorm2dBwdGammaBetaHostArgs
{
const void* p_x;
const void* p_dY;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_yMul;
index_t m;
index_t n;
......@@ -51,13 +51,13 @@ struct Layernorm2dBwdGammaBeta
struct Kargs
{
const void* p_x;
const void* p_dY;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_yMul;
index_t m;
index_t n;
......@@ -67,12 +67,12 @@ struct Layernorm2dBwdGammaBeta
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_dY,
return Kargs{hargs.p_x,
hargs.p_dY,
hargs.p_mean,
hargs.p_invStd,
hargs.p_dGamma,
hargs.p_dBeta,
hargs.p_yMul,
hargs.m,
hargs.n,
hargs.stride};
......@@ -119,7 +119,22 @@ struct Layernorm2dBwdGammaBeta
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
const auto block_id = get_block_id();
const auto iM = block_id * Block_M;
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
// check the max count dynamically
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto dy_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
......@@ -144,7 +159,7 @@ struct Layernorm2dBwdGammaBeta
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
const auto invstd_window = [&]() {
......@@ -156,36 +171,37 @@ struct Layernorm2dBwdGammaBeta
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
}();
const auto dgamma_window = [&]() {
auto dgamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_dGamma),
make_tuple(kargs.n),
make_tuple(1));
static_cast<GammaDataType*>(kargs.p_dGamma),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
}();
const auto dbeta_window = [&]() {
auto dbeta_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_dBeta),
make_tuple(kargs.n),
make_tuple(1));
static_cast<BetaDataType*>(kargs.p_dBeta),
make_tuple(gridDim.x, kargs.n),
make_tuple(kargs.n, 1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {0});
pad_tensor_view(tmp_, make_tuple(number<1>{}, number<Block_N>{}), sequence<false, kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<1>{}, number<Block_N>{}), {block_id, 0});
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(dy_window,
Pipeline{}(x_window,
dy_window,
mean_window,
invstd_window,
dgamma_window,
......
......@@ -10,7 +10,7 @@ namespace ck_tile {
struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDyBlockTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
......@@ -18,11 +18,11 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>,
sequence<0>>{});
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
......@@ -39,20 +39,22 @@ struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
sequence<0>>{});
}
// template <typename Problem>
// CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
// {
// using S = typename Problem::BlockShape;
// return make_static_tile_distribution(
// tile_distribution_encoding<
// sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
// tuple<sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
// tuple<sequence<0, 1>, sequence<0, 1>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// sequence<0>,
// sequence<0>>{});
// }
// Tile distribution for the per-block partial dgamma/dbeta output tile.
// NOTE(review): the encoding maps the M thread dimensions (WarpPerBlock_M,
// ThreadPerWarp_M) and the N dimensions (Repeat_N, WarpPerBlock_N,
// ThreadPerWarp_N) onto the block's lane layout — presumably so each thread
// owns Repeat_N columns of the single partial row a block emits; confirm
// against the ck_tile tile_distribution_encoding convention.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<2>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
......
......@@ -24,7 +24,7 @@ struct Layernorm2dBwdGammaBetaPipeline
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dBwdGammaBetaProblem::kPadM
static constexpr bool kPadM = false;
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
......@@ -35,31 +35,13 @@ struct Layernorm2dBwdGammaBetaPipeline
{
return Policy::template GetSmemSize<Problem>();
}
// template <typename DumpTensor_>
// CK_TILE_DEVICE void dump(const DumpTensor_& x) const
// {
// constexpr auto I0 = number<0>{};
// constexpr auto I1 = number<1>{};
// constexpr auto spans = DumpTensor_::get_distributed_spans();
// sweep_tile_span(spans[I1], [&](auto i1) {
// sweep_tile_span(spans[I0], [&](auto i0) {
// constexpr auto in_dstr_idx = make_tuple(i0, i1);
// auto v = ck_tile::type_convert<float>(x[in_dstr_idx]);
// index_t tid =
// (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;
// printf("%d %f\n", tid, v);
// });
// });
// }
template <typename DYWindow,
template <typename XWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow>
CK_TILE_DEVICE auto operator()(const DYWindow& dy_window_,
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XWindow& dy_window_,
const MeanWindow& mean_window_,
const InvStdWindow& inv_std_window_,
DGammaWindow& gamma_window_,
......@@ -67,52 +49,43 @@ struct Layernorm2dBwdGammaBetaPipeline
ck_tile::index_t row_size,
void* smem) const
{
const auto dy_window = make_tile_window(dy_window_,
Policy::template MakeDyBlockTileDistribution<Problem>());
const auto mean_window = make_tile_window(
mean_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
const auto inv_std_window = make_tile_window(
inv_std_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
// const auto gamma_window = make_tile_window(
// gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
// const auto beta_window = make_tile_window(
// beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto dy = load_tile(dy_window);
const auto mean = load_tile(mean_window);
const auto inv_std = load_tile(inv_std_window);
// auto y = make_static_distributed_tensor<YDataType>(dy.get_tile_distribution());
sweep_tile(mean, [&](auto idx) {
auto gamma_beta_dist = Policy::template MakeGammaBetaBlockTileDistribution<Problem>();
auto mean_dist = Policy::template MakeMeanBlockTileDistribution<Problem>();
auto x_dist = Policy::template MakeXBlockTileDistribution<Problem>();
const auto x_window = make_tile_window(x_window_, x_dist);
const auto dy_window = make_tile_window(dy_window_, x_dist);
const auto mean_window = make_tile_window(mean_window_, mean_dist);
const auto inv_std_window = make_tile_window(inv_std_window_, mean_dist);
const auto x_tile = load_tile(x_window);
const auto dy_tile = load_tile(dy_window);
const auto mean_tile = load_tile(mean_window);
const auto inv_std_tile = load_tile(inv_std_window);
auto gamma_window = make_tile_window(gamma_window_, gamma_beta_dist);
auto beta_window = make_tile_window(beta_window_, gamma_beta_dist);
auto gamma_tile = make_static_distributed_tensor<GammaDataType>(gamma_beta_dist);
auto beta_tile = make_static_distributed_tensor<BetaDataType>(gamma_beta_dist);
sweep_tile(x_tile, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
const auto m = type_convert<float>(mean[i_idx]);
if(blockIdx.x == 0 && blockIdx.y == 0)
printf("%d %f\n", tid, m);
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
constexpr auto gb_idx = make_tuple(number<0>{}, idx[number<1>{}]);
auto &gamma = gamma_tile(gb_idx);
auto &beta = beta_tile(gb_idx);
const auto x = type_convert<ComputeDataType>(x_tile[idx]);
const auto dy = type_convert<ComputeDataType>(dy_tile[idx]);
const auto mean = type_convert<ComputeDataType>(mean_tile[i_idx]);
const auto inv_std = type_convert<ComputeDataType>(inv_std_tile[i_idx]);
beta += type_convert<BetaDataType>(dy);
gamma += type_convert<GammaDataType>(dy * (x - mean) * inv_std);
// index_t tid = (threadIdx.y * blockDim.x) + threadIdx.x;
// if(blockIdx.x < 3 && blockIdx.y == 0 && tid < 3) {
// printf("bid %d tid %d count %d gb %f %f\n",blockIdx.x, tid, count, type_convert<float>(g), type_convert<float>(b));
// }
});
// dump(dy);
// dump(mean);
// dump(inv_std);
*reinterpret_cast<char *>(smem) = row_size;
// layernorm computation
// auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
// sweep_tile(y, [&, mean_ = mean](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
// const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
// const auto x_ = type_convert<ComputeDataType>(x[idx]);
// auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
store_tile(gamma_window, gamma_tile);
store_tile(beta_window, beta_tile);
// y(idx) = type_convert<YDataType>(y_);
// });
// store_tile(y_window, y);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment