Commit f591ad27 authored by rocking

1. Separate gamma and beta from affine

2. Check that arguments are valid
parent 8e2d0ae7
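For callers, the visible effect of the first change is one extra stride-vector parameter: where a single affineStrides argument previously described both gamma and beta, the argument maker now takes gammaStrides and betaStrides separately, so the two tensors can use different layouts. The sketch below is not part of the commit; make_argument_old and make_argument_new are hypothetical stand-ins whose parameter lists mirror only the split shown in this diff.

// Hypothetical stubs for illustration only; they are not the real
// ck::tensor_operation::device::DeviceLayernorm::MakeArgumentPointer.
#include <vector>
using index_t = int;

// Before this commit: one affineStrides vector covered both gamma and beta.
void make_argument_old(std::vector<index_t> inLengths,
                       std::vector<index_t> inStrides,
                       std::vector<index_t> affineStrides,
                       std::vector<index_t> reduceDims) {}

// After this commit: gamma and beta each carry their own stride vector.
void make_argument_new(std::vector<index_t> inLengths,
                       std::vector<index_t> inStrides,
                       std::vector<index_t> gammaStrides,
                       std::vector<index_t> betaStrides,
                       std::vector<index_t> reduceDims) {}

int main()
{
    const index_t M = 4, N = 8, Stride = N;
    make_argument_old({M, N}, {Stride, 1}, {0, 1}, {1});
    // The two {0, 1} vectors below could now differ if gamma and beta were laid out differently.
    make_argument_new({M, N}, {Stride, 1}, {0, 1}, {0, 1}, {1});
}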
@@ -40,9 +40,11 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
1, // SliceM
8, // SliceK
1, // SrcVecDim (0=M, 1=K)
1, // SrcScalarPerVector
1, // AffineVecDim (0=M, 1=K)
1, // AffineScalarPerVector
8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
1>; // OutScalarPerVector
template <typename XDataType,
@@ -129,6 +131,7 @@ int main()
auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
{Stride, 1},
{0, 1},
{0, 1},
{1},
1e-4,
x_dev.GetDeviceBuffer(),
......
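For reference, the stride vectors in the client example read as follows: x and y are M x N row-major ({Stride, 1}), while gamma and beta are length-N vectors broadcast along M ({0, 1}, i.e. stride 0 in the M dimension). Below is a minimal host-side sketch of that addressing, a plain reference layer normalization written for illustration; it is not the library's kernel, and the function name and sample values are assumptions.

#include <cmath>
#include <cstdio>
#include <vector>

// Reference layernorm over the last dimension; x, gamma and beta are indexed
// through explicit (dim0, dim1) strides so the {Stride, 1} / {0, 1} convention
// from the example is visible. y reuses x's strides.
void layernorm_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                   const std::vector<float>& beta, std::vector<float>& y,
                   int M, int N,
                   int xStrideM, int xStrideN,
                   int gStrideM, int gStrideN, // {0, 1}: the same gamma row is reused for every m
                   int bStrideM, int bStrideN, // {0, 1}: the same beta row is reused for every m
                   float epsilon = 1e-4f)
{
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f, var = 0.f;
        for(int n = 0; n < N; ++n)
            mean += x[m * xStrideM + n * xStrideN];
        mean /= N;
        for(int n = 0; n < N; ++n)
        {
            const float d = x[m * xStrideM + n * xStrideN] - mean;
            var += d * d;
        }
        var /= N;
        for(int n = 0; n < N; ++n)
        {
            const float xv = x[m * xStrideM + n * xStrideN];
            const float g  = gamma[m * gStrideM + n * gStrideN];
            const float b  = beta[m * bStrideM + n * bStrideN];
            y[m * xStrideM + n * xStrideN] = (xv - mean) / std::sqrt(var + epsilon) * g + b;
        }
    }
}

int main()
{
    const int M = 2, N = 4, Stride = N;
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> gamma(N, 1.f), beta(N, 0.f), y(M * N);
    layernorm_ref(x, gamma, beta, y, M, N, Stride, 1, 0, 1, 0, 1);
    std::printf("y[0] = %f\n", y[0]); // normalized value of x[0][0]
}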
@@ -34,15 +34,22 @@ template <typename XDataType,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t AffineSrcVectorDim,
index_t AffineSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceLayernorm : public BaseOperator
{
static_assert(
((AffineSrcVectorDim == 0 && MThreadSliceSize % AffineSrcVectorSize == 0) ||
(AffineSrcVectorDim == 1 && KThreadSliceSize % AffineSrcVectorSize == 0)),
"Invalid thread slice sizes and/or affine vector sizes configuration, please check!");
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough;
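The two static_asserts above encode a compile-time rule: whichever dimension gamma (or beta) is vectorized along, the matching thread slice size must be a multiple of the vector width. Below is a stand-alone restatement of the rule, using the widths from the example instance (slice sizes 1 and 8, vector dim K, vector width 8) as assumed inputs; valid_vector_config is a name invented for this sketch.

using index_t = int;

// Same predicate as the static_asserts: the slice along the vectorized
// dimension must be divisible by the vector width.
constexpr bool valid_vector_config(index_t mSlice, index_t kSlice,
                                   index_t vecDim, // 0 = M, 1 = K
                                   index_t vecSize)
{
    return (vecDim == 0 && mSlice % vecSize == 0) ||
           (vecDim == 1 && kSlice % vecSize == 0);
}

// GammaSrcVectorDim = 1, GammaSrcVectorSize = 8 with KThreadSliceSize = 8: accepted.
static_assert(valid_vector_config(1, 8, 1, 8), "gamma vector config");
// BetaSrcVectorDim = 1, BetaSrcVectorSize = 8: also accepted.
static_assert(valid_vector_config(1, 8, 1, 8), "beta vector config");
// A width of 3 does not divide KThreadSliceSize = 8, so the predicate rejects it.
static_assert(!valid_vector_config(1, 8, 1, 3), "deliberately invalid combination");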
@@ -83,8 +90,10 @@ struct DeviceLayernorm : public BaseOperator
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
AffineSrcVectorDim,
AffineSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
OutDstVectorSize,
false>;
@@ -92,7 +101,8 @@ struct DeviceLayernorm : public BaseOperator
{
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
const XDataType* p_x,
@@ -116,14 +126,16 @@ struct DeviceLayernorm : public BaseOperator
p_gamma_(p_gamma),
p_beta_(p_beta)
{
affineStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(affineStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
}
AccDataType epsilon_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
std::vector<index_t> affineStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
};
struct Invoker : public BaseInvoker
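Both new stride vectors are passed through shuffle_tensor_dimensions before use. Judging from how the shuffled strides are indexed later (fastest invariant dimension at [NumInvariantDim - 1], fastest reduced dimension at [Rank - 1]), the function is assumed to reorder per-dimension values so that non-reduced dimensions come first and reduced dimensions last. The sketch below illustrates that assumed behavior only; shuffle_dims_sketch is not the library's shuffle_tensor_dimensions.

#include <algorithm>
#include <vector>
using index_t = int;

// Assumed reordering: keep invariant (non-reduced) dimensions in order, then
// append the reduced dimensions. Works on lengths and strides alike.
std::vector<index_t> shuffle_dims_sketch(const std::vector<index_t>& dims,
                                         const std::vector<index_t>& reduceDims)
{
    std::vector<index_t> invariant, reduced;
    for(index_t d = 0; d < static_cast<index_t>(dims.size()); ++d)
    {
        const bool isReduced =
            std::find(reduceDims.begin(), reduceDims.end(), d) != reduceDims.end();
        (isReduced ? reduced : invariant).push_back(dims[d]);
    }
    invariant.insert(invariant.end(), reduced.begin(), reduced.end());
    return invariant;
}

// Example: gammaStrides {0, 1} with reduceDims {1} (last dim reduced) is
// already in invariant-then-reduced order, so it comes back unchanged:
// shuffle_dims_sketch({0, 1}, {1}) == {0, 1}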
@@ -133,9 +145,9 @@ struct DeviceLayernorm : public BaseOperator
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
arg.inLengths_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
arg.inLengths_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
@@ -189,14 +201,50 @@ struct DeviceLayernorm : public BaseOperator
{
return false;
}
// TODO - Check AffineSrcVectorDim and AffineSrcVectorSize
// if fastest dim is not reduced
if constexpr(GammaSrcVectorDim == 0)
{
if(p_arg_->gammaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % GammaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % GammaSrcVectorSize != 0)
return (false);
}
// if fastest dim is not reduced
if constexpr(BetaSrcVectorDim == 0)
{
if(p_arg_->betaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
return true;
};
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<int> reduceDims,
AccDataType epsilon,
const void* p_x,
@@ -206,7 +254,8 @@ struct DeviceLayernorm : public BaseOperator
{
return std::make_unique<Argument>(inLengths,
inStrides,
affineStrides,
gammaStrides,
betaStrides,
reduceDims,
epsilon,
static_cast<const XDataType*>(p_x),
......
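The new IsSupportedArgument checks above replace the old TODO comment: for each of gamma and beta, the dimension chosen for vectorized reads must be contiguous (stride 1 after shuffling) and the corresponding lowest length must be divisible by the vector width. Below is a stand-alone sketch of that runtime predicate; the helper name and sample values are assumptions for illustration, not the library's code.

#include <cstdio>
#include <vector>
using index_t = int;

// Runtime counterpart of the compile-time divisibility rule: the vectorized
// dimension must have stride 1 and a length that is a multiple of the width.
bool supported_param_read(const std::vector<index_t>& strides, // shuffled: invariant dims first
                          index_t numInvariantDim,
                          index_t rank,
                          index_t vectorDim,  // 0 = fastest invariant (M), 1 = fastest reduced (K)
                          index_t vectorSize,
                          index_t invariantLowestLength,
                          index_t reduceLowestLength)
{
    if(vectorDim == 0)
    {
        if(strides[numInvariantDim - 1] != 1)
            return false;
        return invariantLowestLength % vectorSize == 0;
    }
    if(strides[rank - 1] != 1)
        return false;
    return reduceLowestLength % vectorSize == 0;
}

int main()
{
    // Gamma broadcast along M with shuffled strides {0, 1}, reduce length 1024,
    // vector width 8 reading along K: accepted (prints 1).
    std::printf("%d\n", supported_param_read({0, 1}, 1, 2, 1, 8, 0, 1024));
    // Reduce length 1023 is not a multiple of 8: rejected (prints 0).
    std::printf("%d\n", supported_param_read({0, 1}, 1, 2, 1, 8, 0, 1023));
}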
@@ -59,8 +59,10 @@ template <typename XDataType,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t AffineSrcVectorDim,
index_t AffineSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t OutDstVectorSize,
bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
@@ -205,8 +207,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
AffineSrcVectorDim,
AffineSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
1,
true>(
gamma_grid_desc_m_k,
@@ -220,8 +222,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
AffineSrcVectorDim,
AffineSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
1,
true>(
beta_grid_desc_m_k,
......