Commit 6ed9ab3a authored by rocking's avatar rocking
Browse files

Use 1d descriptor for gamma and beta

parent 3bb0cbe7
...@@ -41,9 +41,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType, ...@@ -41,9 +41,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
8, // SliceK 8, // SliceK
1, // SrcVecDim (0=M, 1=K) 1, // SrcVecDim (0=M, 1=K)
8, // SrcScalarPerVector 8, // SrcScalarPerVector
1, // GammaVecDim (0=M, 1=K)
8, // GammaScalarPerVector 8, // GammaScalarPerVector
1, // BetaVecDim (0=M, 1=K)
8, // BetaScalarPerVector 8, // BetaScalarPerVector
1>; // OutScalarPerVector 1>; // OutScalarPerVector
...@@ -97,7 +95,7 @@ int main() ...@@ -97,7 +95,7 @@ int main()
ck::index_t M = 1024; ck::index_t M = 1024;
ck::index_t N = 1024; ck::index_t N = 1024;
ck::index_t Stride = 1024; ck::index_t Stride = N;
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor(std::vector<std::size_t>({len}),
...@@ -128,16 +126,17 @@ int main() ...@@ -128,16 +126,17 @@ int main()
beta_dev.ToDevice(beta.mData.data()); beta_dev.ToDevice(beta.mData.data());
auto device_instance = DeviceInstance{}; auto device_instance = DeviceInstance{};
auto argument_ptr = device_instance.MakeArgumentPointer({M, N}, auto argument_ptr = device_instance.MakeArgumentPointer(
{Stride, 1}, {M, N},
{0, 1}, std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
{0, 1}, std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
{1}, std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
1e-4, {1},
x_dev.GetDeviceBuffer(), 1e-4,
gamma_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer()); beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -34,21 +34,17 @@ template <typename XDataType, ...@@ -34,21 +34,17 @@ template <typename XDataType,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t XSrcVectorDim, index_t XSrcVectorDim,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize>
struct DeviceLayernorm : public BaseOperator struct DeviceLayernorm : public BaseOperator
{ {
static_assert( static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || (KThreadSliceSize % GammaSrcVectorSize == 0),
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert( static_assert(
((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || (KThreadSliceSize % BetaSrcVectorSize == 0),
(BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
...@@ -75,7 +71,38 @@ struct DeviceLayernorm : public BaseOperator ...@@ -75,7 +71,38 @@ struct DeviceLayernorm : public BaseOperator
XSrcVectorSize, XSrcVectorSize,
1>; // YDstVectorSize 1>; // YDstVectorSize
static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
const std::vector<index_t>& Strides,
int blkGroupSize,
int numBlockTileIteration)
{
const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
auto grid_desc_k = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(tupleLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration;
const auto Pad_K = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
auto grid_desc_k_padded = transform_tensor_descriptor(
grid_desc_k,
make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (grid_desc_k_padded);
};
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridDesc_K = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType, using GridwiseReduceLayernormGeneric = GridwiseLayernorm_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
...@@ -83,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -83,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator
YDataType, YDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -90,9 +118,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -90,9 +118,7 @@ struct DeviceLayernorm : public BaseOperator
KThreadSliceSize, KThreadSliceSize,
XSrcVectorDim, XSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorSize, YDstVectorSize,
false>; false>;
...@@ -103,6 +129,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -103,6 +129,7 @@ struct DeviceLayernorm : public BaseOperator
YDataType, YDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -110,9 +137,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -110,9 +137,7 @@ struct DeviceLayernorm : public BaseOperator
KThreadSliceSize, KThreadSliceSize,
XSrcVectorDim, XSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
YDstVectorSize, YDstVectorSize,
true>; true>;
...@@ -144,16 +169,22 @@ struct DeviceLayernorm : public BaseOperator ...@@ -144,16 +169,22 @@ struct DeviceLayernorm : public BaseOperator
PassThrough{}), PassThrough{}),
epsilon_(epsilon), epsilon_(epsilon),
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_beta_(p_beta) p_beta_(p_beta),
gammaStrides_(gammaStrides),
betaStrides_(betaStrides)
{ {
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); reduceLength_.resize(NumReduceDim);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); for(int i = 0; i < NumReduceDim; ++i)
{
reduceLength_[i] = lengths[reduceDims[i]];
}
} }
AccDataType epsilon_; AccDataType epsilon_;
const GammaDataType* p_gamma_; const GammaDataType* p_gamma_;
const BetaDataType* p_beta_; const BetaDataType* p_beta_;
std::vector<index_t> reduceLength_;
std::vector<index_t> gammaStrides_; std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_; std::vector<index_t> betaStrides_;
}; };
...@@ -164,10 +195,10 @@ struct DeviceLayernorm : public BaseOperator ...@@ -164,10 +195,10 @@ struct DeviceLayernorm : public BaseOperator
{ {
const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto gamma_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
arg.inLengths_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto beta_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto beta_grid_desc_k = MakeAffine1dDescriptor(
arg.inLengths_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
...@@ -180,14 +211,16 @@ struct DeviceLayernorm : public BaseOperator ...@@ -180,14 +211,16 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
GridDesc_M_K> GridDesc_M_K,
GridDesc_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric, : kernel_layernorm<GridwiseReduceLayernormGeneric,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
GridDesc_M_K>; GridDesc_M_K,
GridDesc_K>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -196,8 +229,8 @@ struct DeviceLayernorm : public BaseOperator ...@@ -196,8 +229,8 @@ struct DeviceLayernorm : public BaseOperator
dim3(BlockSize), dim3(BlockSize),
0, 0,
x_grid_desc_m_k, x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_k,
beta_grid_desc_m_k, beta_grid_desc_k,
y_grid_desc_m_k, y_grid_desc_m_k,
arg.numBlockTileIteration, arg.numBlockTileIteration,
arg.epsilon_, arg.epsilon_,
...@@ -230,41 +263,26 @@ struct DeviceLayernorm : public BaseOperator ...@@ -230,41 +263,26 @@ struct DeviceLayernorm : public BaseOperator
return false; return false;
} }
// if fastest dim is not reduced if(p_arg_->gammaStrides_.size() != NumReduceDim ||
if constexpr(GammaSrcVectorDim == 0) p_arg_->betaStrides_.size() != NumReduceDim)
{ return false;
if(p_arg_->gammaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % GammaSrcVectorSize != 0) auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
return (false); bool ret = true;
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % GammaSrcVectorSize != 0) if(!isLastDimensionCoalesced)
return (false); ret = scalarPerVector == 1;
} else
ret = KThreadSliceSize % scalarPerVector == 0;
// if fastest dim is not reduced return ret;
if constexpr(BetaSrcVectorDim == 0) };
{
if(p_arg_->betaStrides_[Reduction::NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
return (false); return false;
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->reduce_lowest_length % BetaSrcVectorSize != 0) if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
return (false); return false;
}
return true; return true;
}; };
......
...@@ -20,10 +20,11 @@ template <typename GridwiseReduction, ...@@ -20,10 +20,11 @@ template <typename GridwiseReduction,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename AccDataType,
typename GridDesc_M_K> typename GridDesc_M_K,
typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_K gamma_grid_desc_k,
const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_K beta_grid_desc_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
...@@ -33,8 +34,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, ...@@ -33,8 +34,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
YDataType* const __restrict__ p_y_global) YDataType* const __restrict__ p_y_global)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_k,
beta_grid_desc_m_k, beta_grid_desc_k,
y_grid_desc_m_k, y_grid_desc_m_k,
num_k_block_tile_iteration, num_k_block_tile_iteration,
epsilon, epsilon,
...@@ -50,6 +51,7 @@ template <typename XDataType, ...@@ -50,6 +51,7 @@ template <typename XDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename AccDataType,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -57,9 +59,7 @@ template <typename XDataType, ...@@ -57,9 +59,7 @@ template <typename XDataType,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t XSrcVectorDim, index_t XSrcVectorDim,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t GammaSrcVectorDim,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize, index_t YDstVectorSize,
bool SweepOnce> bool SweepOnce>
...@@ -114,8 +114,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -114,8 +114,8 @@ struct GridwiseLayernorm_mk_to_mk
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_K& gamma_grid_desc_k,
const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_K& beta_grid_desc_k,
const GridDesc_M_K& y_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
...@@ -141,11 +141,9 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -141,11 +141,9 @@ struct GridwiseLayernorm_mk_to_mk
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf; x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
gamma_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> beta_thread_buf;
beta_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf; y_thread_buf;
...@@ -175,15 +173,18 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -175,15 +173,18 @@ struct GridwiseLayernorm_mk_to_mk
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( using ThreadBufferLengths_K = Sequence<KThreadSliceSize>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_k =
make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType, auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_M_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc_m_k),
ThreadBufferLengths, ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
XSrcVectorDim, XSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
...@@ -194,67 +195,68 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -194,67 +195,68 @@ struct GridwiseLayernorm_mk_to_mk
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<GammaDataType, auto threadwise_gamma_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<GammaDataType,
GridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), GridDesc_K,
ThreadBufferLengths, decltype(thread_buffer_desc_k),
ThreadBufferDimAccessOrder, ThreadBufferLengths_K,
GammaSrcVectorDim, Sequence<0>,
GammaSrcVectorSize, 0,
1, GammaSrcVectorSize,
true>( 1,
gamma_grid_desc_m_k, true>(
make_multi_index(block_global_id * M_BlockTileSize + gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType, auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
AccDataType, AccDataType,
GridDesc_M_K, GridDesc_K,
decltype(thread_buffer_desc), decltype(thread_buffer_desc_k),
ThreadBufferLengths, ThreadBufferLengths_K,
ThreadBufferDimAccessOrder, Sequence<0>,
BetaSrcVectorDim, 0,
BetaSrcVectorSize, BetaSrcVectorSize,
1, 1,
true>( true>(
beta_grid_desc_m_k, beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, auto threadwise_y_store =
thread_k_cluster_id * KThreadSliceSize)); ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
auto threadwise_y_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType, decltype(thread_buffer_desc_m_k),
YDataType, GridDesc_M_K,
decltype(thread_buffer_desc), PassThroughOp,
GridDesc_M_K, ThreadBufferLengths_M_K,
PassThroughOp, ThreadBufferDimAccessOrder,
ThreadBufferLengths, XSrcVectorDim,
ThreadBufferDimAccessOrder, YDstVectorSize,
XSrcVectorDim, InMemoryDataOperationEnum::Set,
YDstVectorSize, 1,
InMemoryDataOperationEnum::Set, true>(
1, y_grid_desc_m_k,
true>( make_multi_index(block_global_id * M_BlockTileSize +
y_grid_desc_m_k, thread_m_cluster_id * MThreadSliceSize,
make_multi_index(block_global_id * M_BlockTileSize + thread_k_cluster_id * KThreadSliceSize),
thread_m_cluster_id * MThreadSliceSize, PassThroughOp{});
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
// Copy x from Cache // Copy x from Cache
// one pass: fwd, second pass: bwd // one pass: fwd, second pass: bwd
constexpr auto thread_copy_fwd_step = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); constexpr auto thread_copy_bwd_step_k = make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
constexpr auto thread_copy_fwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
// E(x), E[x^2], var(x) // E(x), E[x^2], var(x)
int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1]; int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
...@@ -264,22 +266,23 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -264,22 +266,23 @@ struct GridwiseLayernorm_mk_to_mk
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_val_buf,
thread_buffer_desc, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); constexpr auto offset_m_k =
x_square_thread_buf(Number<offset>{}) = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
x_thread_buf(Number<offset>{}) * x_thread_buf(Number<offset>{}); x_square_thread_buf(Number<offset_m_k>{}) =
x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
}); });
}); });
ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf); ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf); ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
++reducedTiles; ++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
...@@ -297,12 +300,13 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -297,12 +300,13 @@ struct GridwiseLayernorm_mk_to_mk
}); });
// y = (x - E[x]) / sqrt(var[x] + epsilon) // y = (x - E[x]) / sqrt(var[x] + epsilon)
auto thread_copy_tail = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step; auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
auto thread_copy_tail_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail); threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
reducedTiles = 0; reducedTiles = 0;
do do
...@@ -311,48 +315,51 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -311,48 +315,51 @@ struct GridwiseLayernorm_mk_to_mk
{ {
threadwise_x_load.Run(x_grid_desc_m_k, threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf, x_global_val_buf,
thread_buffer_desc, thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
x_thread_buf); x_thread_buf);
} }
threadwise_gamma_load.Run(gamma_grid_desc_m_k, threadwise_gamma_load.Run(gamma_grid_desc_k,
gamma_global_val_buf, gamma_global_val_buf,
thread_buffer_desc, thread_buffer_desc_k,
make_tuple(I0, I0), make_tuple(I0),
gamma_thread_buf); gamma_thread_buf);
threadwise_beta_load.Run(beta_grid_desc_m_k, threadwise_beta_load.Run(beta_grid_desc_k,
beta_global_val_buf, beta_global_val_buf,
thread_buffer_desc, thread_buffer_desc_k,
make_tuple(I0, I0), make_tuple(I0),
beta_thread_buf); beta_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); constexpr auto offset_m_k =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
// normalize // normalize
y_thread_buf(Number<offset>{}) = y_thread_buf(Number<offset_m_k>{}) =
(x_thread_buf(Number<offset>{}) - mean_thread_buf(iM)) / (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
sqrt(var_value_buf(iM) + epsilon); sqrt(var_value_buf(iM) + epsilon);
// affine // affine
y_thread_buf(Number<offset>{}) = y_thread_buf(Number<offset_m_k>{}) =
y_thread_buf(Number<offset>{}) * gamma_thread_buf(Number<offset>{}) + y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{}) +
beta_thread_buf(Number<offset>{}); beta_thread_buf(Number<offset_k>{});
}); });
}); });
threadwise_y_store.Run(thread_buffer_desc, threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0), make_tuple(I0, I0),
y_thread_buf, y_thread_buf,
y_grid_desc_m_k, y_grid_desc_m_k,
y_global_val_buf); y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step); threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
threadwise_gamma_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step); threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
threadwise_beta_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step); threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step); threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
++reducedTiles; ++reducedTiles;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment