Commit 5c9cba82 authored by felix

add l

parent 54f0e6f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// host side args
struct Layernorm2dBwdGammaBetaHostArgs
{
const void* p_dY;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_yMul;
index_t m;
index_t n;
index_t stride; // row_stride
};
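//
// Illustrative host-side fill of these arguments (a sketch only; the device pointer
// names and the row-major (m, n) layout with stride == n are assumptions for the example):
//
//   Layernorm2dBwdGammaBetaHostArgs hargs{};
//   hargs.p_dY     = p_dy_dev;      // [m, n] upstream gradient
//   hargs.p_mean   = p_mean_dev;    // [m] per-row mean saved by the fwd pass
//   hargs.p_invStd = p_invstd_dev;  // [m] per-row 1/std saved by the fwd pass
//   hargs.p_dGamma = p_dgamma_dev;  // [n] output
//   hargs.p_dBeta  = p_dbeta_dev;   // [n] output
//   hargs.p_yMul   = nullptr;       // not consumed by the current pipeline
//   hargs.m = m; hargs.n = n; hargs.stride = n;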
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct Layernorm2dBwdGammaBeta
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // no need to pad along M for this kernel
static constexpr bool kPadN = Problem::kPadN;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
const void* p_dY;
const void* p_mean;
const void* p_invStd;
void* p_dGamma;
void* p_dBeta;
void* p_yMul;
index_t m;
index_t n;
index_t stride; // row_stride
};
using Hargs = Layernorm2dBwdGammaBetaHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_dY,
hargs.p_mean,
hargs.p_invStd,
hargs.p_dGamma,
hargs.p_dBeta,
hargs.p_yMul,
hargs.m,
hargs.n,
hargs.stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return (hargs.m + Block_M - 1) / Block_M;
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// in bytes
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
// clang-format off
using S_ = typename Problem::BlockShape;
// only kPadN contributes to the suffix: kSaveMeanInvStd/kTwoPass are not defined
// for the bwd gamma/beta kernel
auto suffix = [&] () {
std::string n;
if (kPadN) n += "_pn";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("layernorm2d_bwd_gamma_beta_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + suffix;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto iM = get_block_id() * Block_M;
const auto dy_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const YDataType*>(kargs.p_dY),
make_tuple(kargs.m, kargs.n),
make_tuple(kargs.stride, 1));
// NOTE: we don't pad the load in this kernel; we assume the pipeline checks the
// max count dynamically
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
}();
const auto mean_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const MeanDataType*>(kargs.p_mean),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
}();
const auto invstd_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.m),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {0});
}();
const auto dgamma_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_dGamma),
make_tuple(kargs.n),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
const auto dbeta_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_dBeta),
make_tuple(kargs.n),
make_tuple(1));
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(dy_window,
mean_window,
invstd_window,
dgamma_window,
dbeta_window,
kargs.n,
smem);
}
};
} // namespace ck_tile
@@ -94,7 +94,7 @@ struct Layernorm2dFwd
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
-return (hargs.m + Block_M - 1) / Block_M;
+return dim3((hargs.n + Block_N - 1) / Block_N, (hargs.m + Block_M - 1) / Block_M);
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
@@ -124,7 +124,7 @@ struct Layernorm2dFwd
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("layernorm2d_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
return _SS_("layernorm2d_bwd_gamma_beta_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = Layernorm2dBwdGammaBetaPipelineDefaultPolicy>
struct Layernorm2dBwdGammaBetaPipeline
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using GammaDataType = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
using BetaDataType = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using MeanDataType = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
using InvStdDataType = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dBwdGammaBetaProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
return "bwd_gamma_beta";
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename DYWindow,
typename MeanWindow,
typename InvStdWindow,
typename DGammaWindow,
typename DBetaWindow>
CK_TILE_DEVICE auto operator()(const DYWindow& dy_window_,
const MeanWindow& mean_window_,
const InvStdWindow& inv_std_window_,
DGammaWindow& gamma_window_,
DBetaWindow& beta_window_,
ck_tile::index_t row_size,
void* smem) const
{
const auto dy_window =
make_tile_window(dy_window_, Policy::template MakeDyBlockTileDistribution<Problem>());
const auto mean_window = make_tile_window(
mean_window_, Policy::template MakeMeanBlockTileDistribution<Problem>());
// const auto gamma_window = make_tile_window(
// gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
// const auto beta_window = make_tile_window(
// beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
const auto dy = load_tile(dy_window);
const auto mean = load_tile(mean_window);
// layernorm computation
// auto y = make_static_distributed_tensor<YDataType>(x.get_tile_distribution());
// sweep_tile(y, [&, mean_ = mean](auto idx) {
// constexpr auto i_idx = make_tuple(idx[number<0>{}]);
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
// const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
// const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
// const auto x_ = type_convert<ComputeDataType>(x[idx]);
// auto y_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
// y(idx) = type_convert<YDataType>(y_);
// });
// store_tile(y_window, y);
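// Sketch of the reductions this pipeline is expected to produce (standard layernorm
// backward w.r.t. gamma/beta, written in terms of the tiles loaded above; x here
// denotes the fwd input, which is not yet wired into this pipeline):
//   xhat[i, j] = (x[i, j] - mean[i]) * inv_std[i]
//   dgamma[j]  = sum_i dy[i, j] * xhat[i, j]
//   dbeta[j]   = sum_i dy[i, j]
// i.e. a per-column reduction over the M dimension, which is why mean/inv_std are
// per-row tiles and the dgamma/dbeta windows are 1-D over N.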
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
namespace ck_tile {
struct Layernorm2dBwdGammaBetaPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDyBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMeanBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N>,
tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
sequence<1>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// placeholder: the pipeline does not use shared memory yet, so only request a
// minimal allocation
return 1;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename GammaDataType_,
typename BetaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct Layernorm2dBwdGammaBetaPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile
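// Illustrative wiring of the pieces added in this commit (a sketch, not part of the
// library; "MyBlockShape" stands in for whatever 2-D block-shape type the caller
// provides, and the data types are just an fp16-in / fp32-compute example):
//
//   using MyProblem  = ck_tile::Layernorm2dBwdGammaBetaPipelineProblem<
//       ck_tile::fp16_t,   // XDataType
//       ck_tile::fp16_t,   // GammaDataType
//       ck_tile::fp16_t,   // BetaDataType
//       float,             // ComputeDataType
//       ck_tile::fp16_t,   // YDataType (dY is read with this type)
//       float,             // MeanDataType
//       float,             // InvStdDataType
//       MyBlockShape,
//       true,              // kPadN_
//       false,             // kSaveMeanInvStd_ (accepted but unused here)
//       false>;            // kTwoPass_ (accepted but unused here)
//   using MyPipeline = ck_tile::Layernorm2dBwdGammaBetaPipeline<MyProblem>;
//   using MyKernel   = ck_tile::Layernorm2dBwdGammaBeta<MyPipeline>;
//
//   // Host side: MyKernel::MakeKargs(hargs), MyKernel::GridSize(hargs) and
//   // MyKernel::BlockSize() provide everything a launch wrapper needs.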
@@ -37,4 +37,33 @@ struct Layernorm2dFwdPipelineProblem
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename XDataType_,
typename GammaDataType_,
typename BetaDataType_,
typename ComputeDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct Layernorm2dFwdPipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
} // namespace ck_tile