Commit e508995a authored by rocking

Add deviceOp to backward x

parent b9cb4a21
@@ -15,6 +15,7 @@
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
@@ -46,8 +47,34 @@ constexpr int NumReduceDim = 1;
// dbeta = reduce_sum(dy, axis=0)
// [CAUTION]
// In DeviceNormalizationBwdXImpl & DeviceNormalizationBwdGammaBetaImpl, M is the invariant
// dimension and K is the reduced dimension. Hence, M in this example and M in
// DeviceNormalizationBwdGammaBetaImpl are different.
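// For reference, a sketch of the backward-x math the new device op computes (the standard
// layernorm backward formula; the CPU reference lives in reference_layernorm_bwd.hpp):
// with x_hat = (x - mean) * inv_std and mean_k(.) averaging over the reduced axis,
//   dx = inv_std * (dy * gamma - mean_k(dy * gamma) - x_hat * mean_k(dy * gamma * x_hat))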
using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdXImpl<
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
1, // MThreadSliceSize
8, // KThreadSliceSize
true, // IsDYFastestDimReduced
8, // DYSrcVectorSize
true, // IsXFastestDimReduced
8, // XSrcVectorSize
true, // IsGammaFastestDimReduced
8, // GammaSrcVectorSize
false, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
true, // IsDXFastestDimReduced
8>; // DXDstVectorSize
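// Note: this instance has to satisfy the static_asserts in DeviceNormalizationBwdXImpl:
// BlockSize == MThreadClusterSize * KThreadClusterSize (256 == 8 * 32), and every vector size
// must divide the slice size of the dimension it vectorizes (the fastest-dim-reduced operands
// here use KThreadSliceSize == 8; mean/inv_std with vector size 1 pair with MThreadSliceSize == 1).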
using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<
DYDataType,
XDataType,
@@ -58,18 +85,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
Rank,
NumReduceDim,
256, // BlockSize
8, // MThreadClusterSize
32, // KThreadClusterSize
8, // MThreadSliceSize
1, // KThreadSliceSize
false, // IsDYFastestDimReduced
8, // DYSrcVectorSize
false, // IsXFastestDimReduced
8, // XSrcVectorSize
true, // IsMeanInvStdFastestDimReduced
1, // MeanInvStdSrcVectorSize
8, // DGammaDstVectorSize
8>; // DBetaDstVectorSize
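// Note: per the caution above, the roles are swapped for the dgamma/dbeta reduction: the
// example's N becomes the impl's invariant M dimension, so the DGamma/DBeta destination
// vector size of 8 assumes N is a multiple of 8.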
int main()
{
@@ -96,8 +123,10 @@ int main()
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
@@ -106,6 +135,34 @@ int main()
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// backward x
auto x_device_instance = XDeviceInstance{};
auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths
{N, 1}, // dyStrides
{N, 1}, // xStrides
{0, 1}, // gammaStrides
{1, 0}, // meanStrides
{1, 0}, // invStdStrides
{N, 1}, // dxStrides
{1}, // reduceDims
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dx_dev.GetDeviceBuffer());
if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
{
std::cout << "The runtime parameters are not supported" << std::endl;
return 1;
}
auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
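// A typical follow-up here (a sketch following other CK layernorm examples; the host-side dx
// tensor and comparison step are assumed, not shown in this hunk):
//   dx_dev.FromDevice(dx.mData.data());
//   // then compare against the CPU reference from reference_layernorm_bwd.hpp, e.g. with
//   // ck::utils::check_err(dx.mData, host_dx.mData)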
// backward gamma & beta
auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
auto gamma_beta_argument_ptr =
gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
struct DeviceNormalizationBwdX : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim>
using DeviceNormalizationBwdXPtr = std::unique_ptr<DeviceNormalizationBwdX<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>>;
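// Usage sketch (illustrative only): the alias erases the tuning parameters so client code can
// hold any matching backward-x instance polymorphically, e.g.
//   DeviceNormalizationBwdXPtr<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, 2, 1>
//       op = std::make_unique<MyBwdXInstance>(); // MyBwdXInstance: a hypothetical alias for
//                                                // a fully specialized DeviceNormalizationBwdXImpl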
} // namespace device
} // namespace tensor_operation
} // namespace ck
ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is the invariant dimension, K is the reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
bool IsDYFastestDimReduced,
index_t DYSrcVectorSize,
bool IsXFastestDimReduced,
index_t XSrcVectorSize,
bool IsGammaFastestDimReduced,
index_t GammaSrcVectorSize,
bool IsMeanInvStdFastestDimReduced,
index_t MeanInvStdSrcVectorSize,
bool IsDXFastestDimReduced,
index_t DXDstVectorSize>
struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>
{
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
static constexpr index_t DXDstVectorDim = IsDXFastestDimReduced ? 1 : 0;
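// In the merged 2D view built below, vector dim 0 accesses along the invariant (M) dimension
// and vector dim 1 along the reduced (K) dimension.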
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
(DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!");
static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
(DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim);
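// Make2dDescriptor merges the invariant dimensions into a single M axis and the reduced
// dimensions into a single K axis, then right-pads both so that M is a multiple of
// M_BlockTileSize and K spans numBlockTileIteration full K block tiles.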
static auto Make2dDescriptor(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
int numBlockTileIteration)
{
const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});
const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
return transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, pad_M),
make_right_pad_transform(reduceLength, pad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_m_k_padded;
}
using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const DYDataType* p_dy,
const XDataType* p_x,
const GammaDataType* p_gamma,
const MeanInvStdDataType* p_mean,
const MeanInvStdDataType* p_invStd,
DXDataType* p_dx)
: p_dy_(p_dy),
p_x_(p_x),
p_gamma_(p_gamma),
p_mean_(p_mean),
p_invStd_(p_invStd),
p_dx_(p_dx)
{
lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
// dxStrides must be shuffled as well so that it lines up with the shuffled lengths_
dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
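// For illustration (sizes assumed): with the example's instance (M_BlockTileSize = 8 * 1 = 8,
// K_BlockTileSize = 32 * 8 = 256), KRaw = 1024 yields numBlockTileIteration_ = 4 and
// MRaw = 1000 yields gridSize_ = 125 workgroups.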
dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_);
x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_);
mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_);
inv_std_grid_desc_m_k_ =
Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);
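// "Sweep once": the whole (padded) reduced extent fits in one K block tile, so the kernel
// can consume each row in a single pass instead of iterating over K tiles.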
isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
}
const DYDataType* p_dy_;
const XDataType* p_x_;
const GammaDataType* p_gamma_;
const MeanInvStdDataType* p_mean_;
const MeanInvStdDataType* p_invStd_;
DXDataType* p_dx_;
std::vector<index_t> lengths_;
std::vector<index_t> dyStrides_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> meanStrides_;
std::vector<index_t> invStdStrides_;
std::vector<index_t> dxStrides_;
int numBlockTileIteration_;
size_t gridSize_;
// tensor descriptor
GridDesc_M_K dy_grid_desc_m_k_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K mean_grid_desc_m_k_;
GridDesc_M_K inv_std_grid_desc_m_k_;
GridDesc_M_K dx_grid_desc_m_k_;
bool isSweeponce_;
index_t MRaw_; // invariant length
index_t KRaw_; // reduced length
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// TODO: launch the gridwise backward-x kernel; this commit only adds the device-op scaffolding
ignore = arg;
ignore = stream_config;
return 0;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
template <index_t SrcVectorDim, index_t SrcVectorSize>
bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
if constexpr(SrcVectorSize == 1)
return true;
// Fastest dimension is not reduced
if constexpr(SrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
return false;
if(strides[NumInvariantDim - 1] != 1)
return false;
if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
return false;
}
else // Fastest dimension is reduced
{
if(strides[Rank - 1] != 1)
return false;
if(lengths[Rank - 1] % SrcVectorSize != 0)
return false;
}
return true;
}
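// Runtime complement to the compile-time checks above: vectorized access also requires the
// vectorized dimension to be contiguous (stride 1) and its length divisible by the vector size.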
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
bool pass = true;
pass &= IsVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->lengths_,
p_arg_->dyStrides_);
pass &= IsVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->lengths_,
p_arg_->xStrides_);
pass &= IsVectorDimSizeValid<GammaSrcVectorDim, GammaSrcVectorSize>(p_arg_->lengths_,
p_arg_->gammaStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->meanStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->invStdStrides_);
pass &= IsVectorDimSizeValid<DXDstVectorDim, DXDstVectorSize>(p_arg_->lengths_,
p_arg_->dxStrides_);
return pass;
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) override
{
if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
gammaStrides.size() != Rank || meanStrides.size() != Rank ||
invStdStrides.size() != Rank || dxStrides.size() != Rank)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths,
dyStrides,
xStrides,
gammaStrides,
meanStrides,
invStdStrides,
dxStrides,
reduceDims,
static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DXDataType*>(p_dx));
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdXImpl<" << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck