"vscode:/vscode.git/clone" did not exist on "98acc5a8874faab28b82c28936f4b400b389f5d6"
Commit d62f0358 authored by rocking
Browse files

Remove save mean and inv std

parent 29cff07e
......@@ -52,11 +52,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});
// TODO - move SAVE_MEAN_INV_STD to user args
#ifdef SAVE_MEAN_INV_STD
ck_tile::HostTensor<MeanDataType> mean_host_dev({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_dev({M});
#endif
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
......@@ -66,10 +61,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
#ifdef SAVE_MEAN_INV_STD
ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes());
#endif
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
......@@ -81,13 +72,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
#ifdef SAVE_MEAN_INV_STD
mean_buf.GetDeviceBuffer(),
invStd_buf.GetDeviceBuffer(),
#else
nullptr,
nullptr,
#endif
epsilon,
M,
N};
......
......@@ -8,11 +8,6 @@
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
// Traits bundle passed by value as the first argument of layernorm2d_fwd().
struct layernorm2d_fwd_traits
{
// Name of the element data type, held as a string.
// NOTE(review): presumably selects which LayerNormTypeConfig specialization
// (ck_tile::half_t or float) is dispatched; the accepted strings are not
// visible here - confirm against the layernorm2d_fwd() definition.
std::string data_type;
};
template <typename DataType>
struct LayerNormTypeConfig;
......@@ -23,13 +18,8 @@ struct LayerNormTypeConfig<ck_tile::half_t>
using YDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using BetaDataType = ck_tile::half_t;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = ck_tile::half_t;
using InvStdDataType = ck_tile::half_t;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
};
......@@ -40,16 +30,12 @@ struct LayerNormTypeConfig<float>
using YDataType = float;
using GammaDataType = float;
using BetaDataType = float;
#ifdef SAVE_MEAN_INV_STD
using MeanDataType = float;
using InvStdDataType = float;
#else
using MeanDataType = ck_tile::null_type;
using InvStdDataType = ck_tile::null_type;
#endif
using ComputeDataType = float;
};
// runtime args
struct layernorm2d_fwd_args
{
const void* p_x;
......@@ -63,5 +49,10 @@ struct layernorm2d_fwd_args
ck_tile::index_t N;
};
// host API
// This is the public API; it will be generated by a script.
// Traits bundle passed by value as the first argument of layernorm2d_fwd().
struct layernorm2d_fwd_traits
{
// Name of the element data type, held as a string.
// NOTE(review): presumably selects which LayerNormTypeConfig specialization
// (ck_tile::half_t or float) is dispatched; the accepted strings are not
// visible here - confirm against the layernorm2d_fwd() definition.
std::string data_type;
};
// Host-side entry point for the layernorm2d forward operation.
// Takes the traits bundle and the runtime argument pack by value, plus the
// stream configuration by const reference.
// NOTE(review): the float return value is presumably the measured kernel
// time - confirm against the definition, which is not visible here.
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
......@@ -14,6 +14,7 @@ template <typename InOutDataType,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kPadN,
bool kSaveMeanInvStd,
bool kTwoPass>
struct layernorm_dispatch
{
......@@ -38,6 +39,7 @@ struct layernorm_dispatch
typename LayerNormTypeConfig<InOutDataType>::InvStdDataType,
Shape,
kPadN,
kSaveMeanInvStd,
kTwoPass>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
......@@ -75,6 +77,13 @@ template <typename InOutDataType,
bool kTwoPass = false>
float run_layernorm(const layernorm2d_fwd_args& param, ck_tile::stream_config stream)
{
    // Instantiates layernorm_dispatch for the given compile-time configuration and
    // forwards the runtime arguments and stream to its Run(); the float it returns
    // is passed straight back to the caller.

    // TODO - Add SaveMeanInvStd instance
    // Saving mean / inverse std is not instantiated yet, so it is pinned to false.
    constexpr bool kSaveMeanInvStd = false;

    // The arguments must follow layernorm_dispatch's declared parameter order:
    //   <InOutDataType, NRepeat, NThread, VectorAccessSize, kPadN, kSaveMeanInvStd, kTwoPass>
    // kPadN and kSaveMeanInvStd are both bool, so transposing them compiles silently
    // but would disable padding and route the caller's kPadN into the save flag.
    return layernorm_dispatch<InOutDataType,
                              NRepeat,
                              NThread,
                              VectorAccessSize,
                              kPadN,
                              kSaveMeanInvStd,
                              kTwoPass>::Run(param, stream);
}
......@@ -26,8 +26,8 @@ struct Layernorm2dFwd
static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = !std::is_same_v<MeanDataType, ck_tile::null_type>;
static constexpr bool kSaveInvStd = !std::is_same_v<InvStdDataType, ck_tile::null_type>;
static constexpr bool kSaveMean = Problem::kSaveMeanInvStd;
static constexpr bool kSaveInvStd = Problem::kSaveMeanInvStd;
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
......
......@@ -16,6 +16,7 @@ template <typename XDataType_,
typename InvStdDataType_,
typename BlockShape_,
bool kPadN_,
bool kSaveMeanInvStd_,
bool kTwoPass_>
struct BlockLayernorm2dFwdProblem
{
......@@ -27,7 +28,9 @@ struct BlockLayernorm2dFwdProblem
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kTwoPass = kTwoPass_;
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment