Commit f591ad27 authored by rocking's avatar rocking
Browse files

1. Separate gamma and beta from affine

2. Check if argument is valid
parent 8e2d0ae7
...@@ -40,9 +40,11 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType, ...@@ -40,9 +40,11 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
1, // SliceM 1, // SliceM
8, // SliceK 8, // SliceK
1, // SrcVecDim (0=M, 1=K) 1, // SrcVecDim (0=M, 1=K)
1, // SrcScalarPerVector 8, // SrcScalarPerVector
1, // AffineVecDim (0=M, 1=K) 1, // GammaVecDim (0=M, 1=K)
1, // AffineScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector
1>; // OutScalarPerVector 1>; // OutScalarPerVector
template <typename XDataType, template <typename XDataType,
...@@ -129,6 +131,7 @@ int main() ...@@ -129,6 +131,7 @@ int main()
auto argument_ptr = device_instance.MakeArgumentPointer({M, N}, auto argument_ptr = device_instance.MakeArgumentPointer({M, N},
{Stride, 1}, {Stride, 1},
{0, 1}, {0, 1},
{0, 1},
{1}, {1},
1e-4, 1e-4,
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
......
...@@ -34,15 +34,22 @@ template <typename XDataType, ...@@ -34,15 +34,22 @@ template <typename XDataType,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t AffineSrcVectorDim, index_t GammaSrcVectorDim,
index_t AffineSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceLayernorm : public BaseOperator struct DeviceLayernorm : public BaseOperator
{ {
static_assert( static_assert(
((AffineSrcVectorDim == 0 && MThreadSliceSize % AffineSrcVectorSize == 0) || ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(AffineSrcVectorDim == 1 && KThreadSliceSize % AffineSrcVectorSize == 0)), (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or affine vector sizes configuration, please check!"); "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
...@@ -83,8 +90,10 @@ struct DeviceLayernorm : public BaseOperator ...@@ -83,8 +90,10 @@ struct DeviceLayernorm : public BaseOperator
KThreadSliceSize, KThreadSliceSize,
InSrcVectorDim, InSrcVectorDim,
InSrcVectorSize, InSrcVectorSize,
AffineSrcVectorDim, GammaSrcVectorDim,
AffineSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
OutDstVectorSize, OutDstVectorSize,
false>; false>;
...@@ -92,7 +101,8 @@ struct DeviceLayernorm : public BaseOperator ...@@ -92,7 +101,8 @@ struct DeviceLayernorm : public BaseOperator
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon, AccDataType epsilon,
const XDataType* p_x, const XDataType* p_x,
...@@ -116,14 +126,16 @@ struct DeviceLayernorm : public BaseOperator ...@@ -116,14 +126,16 @@ struct DeviceLayernorm : public BaseOperator
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta) p_beta_(p_beta)
{ {
affineStrides_ = gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
shuffle_tensor_dimensions<Rank, NumReduceDim>(affineStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
} }
AccDataType epsilon_; AccDataType epsilon_;
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
std::vector<index_t> affineStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
...@@ -133,9 +145,9 @@ struct DeviceLayernorm : public BaseOperator ...@@ -133,9 +145,9 @@ struct DeviceLayernorm : public BaseOperator
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.affineStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
...@@ -189,14 +201,50 @@ struct DeviceLayernorm : public BaseOperator ...@@ -189,14 +201,50 @@ struct DeviceLayernorm : public BaseOperator
{ {
return false; return false;
} }
// TODO - Check AffineSrcVectorDim and AffineSrcVectorSize
// if fastest dim is not reduced
if constexpr(GammaSrcVectorDim == 0)
{
if(p_arg_->gammaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % GammaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % GammaSrcVectorSize != 0)
return (false);
}
// if fastest dim is not reduced
if constexpr(BetaSrcVectorDim == 0)
{
if(p_arg_->betaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
return true; return true;
}; };
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths, std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
const std::vector<index_t> affineStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
AccDataType epsilon, AccDataType epsilon,
const void* p_x, const void* p_x,
...@@ -206,7 +254,8 @@ struct DeviceLayernorm : public BaseOperator ...@@ -206,7 +254,8 @@ struct DeviceLayernorm : public BaseOperator
{ {
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
affineStrides, gammaStrides,
betaStrides,
reduceDims, reduceDims,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
......
...@@ -59,8 +59,10 @@ template <typename XDataType, ...@@ -59,8 +59,10 @@ template <typename XDataType,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t AffineSrcVectorDim, index_t GammaSrcVectorDim,
index_t AffineSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize,
index_t OutDstVectorSize, index_t OutDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk struct GridwiseLayernorm_mk_to_mk
...@@ -205,8 +207,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -205,8 +207,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
AffineSrcVectorDim, GammaSrcVectorDim,
AffineSrcVectorSize, GammaSrcVectorSize,
1, 1,
true>( true>(
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
...@@ -220,8 +222,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -220,8 +222,8 @@ struct GridwiseLayernorm_mk_to_mk
decltype(thread_buffer_desc), decltype(thread_buffer_desc),
ThreadBufferLengths, ThreadBufferLengths,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
AffineSrcVectorDim, BetaSrcVectorDim,
AffineSrcVectorSize, BetaSrcVectorSize,
1, 1,
true>( true>(
beta_grid_desc_m_k, beta_grid_desc_m_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment