Unverified Commit 24af0144 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 961f5e9e b79bbbc2
...@@ -90,9 +90,9 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -90,9 +90,9 @@ struct ComputePtrOffsetOfStridedBatch
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* pointer offset into \p ComputePtrOffsetOfStridedBatch. * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
* *
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp"
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -24,17 +24,17 @@ template <typename GridwiseReduction, ...@@ -24,17 +24,17 @@ template <typename GridwiseReduction,
typename AccDataType, typename AccDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
typename GridDesc_M_K> typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, __global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
...@@ -54,7 +54,7 @@ namespace ck { ...@@ -54,7 +54,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Y = LayerNorm(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -75,14 +75,14 @@ template <typename XDataType, ...@@ -75,14 +75,14 @@ template <typename XDataType,
index_t BetaSrcVectorDim, index_t BetaSrcVectorDim,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorSize> index_t YDstVectorSize>
struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
YDataType, YDataType,
AccElementwiseOperation, AccElementwiseOperation,
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
static_assert( static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
...@@ -168,49 +168,49 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -168,49 +168,49 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric = using GridwiseReduceLayernormGeneric =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
false>; false>;
using GridwiseReduceLayernormSweepOnce = using GridwiseNormalizationSweepOnce =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
true>; true>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -295,22 +295,22 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -295,22 +295,22 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel_main = arg.isSweeponce_ const auto kernel_main = arg.isSweeponce_
? kernel_layernorm<GridwiseReduceLayernormSweepOnce, ? kernel_normalization<GridwiseNormalizationSweepOnce,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K> GridDesc_M_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric, : kernel_normalization<GridwiseReduceLayernormGeneric,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -426,8 +426,16 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -426,8 +426,16 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean,
void* p_saveInvVar,
AccElementwiseOperation acc_elementwise_op) override AccElementwiseOperation acc_elementwise_op) override
{ {
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore = p_saveMean;
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
...@@ -452,7 +460,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -452,7 +460,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceLayernormImpl<" << BlockSize << ","; str << "DeviceNormalizationImpl<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
......
...@@ -5,13 +5,12 @@ ...@@ -5,13 +5,12 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/utility/common_header.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -41,7 +40,8 @@ template <typename InDataType, ...@@ -41,7 +40,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -58,8 +58,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
// So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
...@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -81,13 +81,15 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides, const std::array<index_t, Rank>& inStrides,
int blkGroupSize, int blkGroupSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -97,7 +99,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -111,10 +113,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -143,18 +145,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -170,18 +174,20 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{}); const auto length = out_grid_desc_m.GetLength(Number<0>{});
...@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -198,11 +204,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -220,6 +226,30 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -220,6 +226,30 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
throw std::runtime_error(
"One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -272,10 +302,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -272,10 +302,10 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -459,11 +489,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -459,11 +489,11 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
...@@ -34,7 +35,8 @@ template <typename InDataType, ...@@ -34,7 +35,8 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation> struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
...@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -49,18 +51,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t NumSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::vector<index_t>& inStrides) const std::array<index_t, Rank>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths =
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -70,7 +74,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
const auto one_dim_inDesc = transform_tensor_descriptor( const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc, inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)), make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc, return transform_tensor_descriptor(one_dim_inDesc,
...@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -84,10 +88,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type; using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = const auto reduceDimLengths = generate_tuple(
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths = const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor( return transform_tensor_descriptor(
inDesc, inDesc,
...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -116,18 +120,20 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths, static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::vector<index_t>& outStrides) const std::array<index_t, NumDstDim>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths =
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor( auto out_grid_desc_m = transform_tensor_descriptor(
outDesc, outDesc,
make_tuple(make_merge_transform(tupleDstLengths)), make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -145,11 +151,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -187,10 +193,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<index_t> inLengths_; std::array<index_t, Rank> inLengths_;
std::vector<index_t> inStrides_; std::array<index_t, Rank> inStrides_;
std::vector<index_t> outLengths_; std::array<index_t, NumDstDim> outLengths_;
std::vector<index_t> outStrides_; std::array<index_t, NumDstDim> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE ...@@ -321,11 +327,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> inLengths, MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::vector<index_t> inStrides, const std::array<index_t, Rank> inStrides,
const std::vector<index_t> outLengths, const std::array<index_t, NumDstDim> outLengths,
const std::vector<index_t> outStrides, const std::array<index_t, NumDstDim> outStrides,
const std::vector<int> reduceDims, const std::array<int, NumReduceDim> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
......
...@@ -8,12 +8,9 @@ ...@@ -8,12 +8,9 @@
#include "ck/utility/reduction_operator.hpp" #include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -43,36 +40,88 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -43,36 +40,88 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
AccElementwiseOp, AccElementwiseOp,
Rank> Rank>
{ {
static constexpr index_t kRank = Rank; static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim; static constexpr index_t kNumReduceDim = NumReduceDim;
static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
virtual index_t GetRank() const override { return kRank; } virtual index_t GetRank() const override { return kRank; }
virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } virtual index_t GetNumReduceDim() const override { return kNumReduceDim; }
// Used for freeloading of some handy functions from DeviceReduceMultiBlock static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
using Reduction = DeviceReduceMultiBlock<InDataType,
AccDataType, static constexpr index_t NumSrcDim = Rank;
OutDataType, static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
Rank, static constexpr bool reduceAllDim = (NumInvariantDim == 0);
NumReduceDim,
reduce::Add, static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
InElementwiseOp, static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
AccElementwiseOp,
InMemoryDataOperationEnum::Set, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
false, // PropagateNan const std::vector<index_t>& inStrides,
false, // OutputIndex int blkGroupSize,
false, // HaveIndexInputIfOutputIndex int numBlockTileIteration)
BlockSize, {
MThreadClusterSize, const auto tupleSrcLengths =
KThreadClusterSize, generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
MThreadSliceSize, const auto tupleSrcStrides =
KThreadSliceSize, generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
InSrcVectorDim,
InSrcVectorSize, const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
1>; // OutDstVectorSize
const auto in_grid_desc_m_k = [&]() {
using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType, using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk<InDataType,
OutDataType, OutDataType,
...@@ -102,7 +151,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -102,7 +151,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDstVectorSize, OutDstVectorSize,
true>; true>;
struct Argument : public Reduction::Argument struct Argument : public BaseArgument
{ {
Argument(const std::vector<index_t> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides, const std::vector<index_t> inStrides,
...@@ -113,42 +162,84 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -113,42 +162,84 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
OutDataType* out_dev, OutDataType* out_dev,
InElementwiseOp in_elementwise_op, InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op) AccElementwiseOp acc_elementwise_op)
: Reduction::Argument(inLengths, : alpha_{alpha},
inStrides, beta_{beta},
{}, in_dev_{in_dev},
{}, out_dev_{out_dev},
reduceDims, in_elementwise_op_{in_elementwise_op},
0.0f, // alpha acc_elementwise_op_{acc_elementwise_op}
0.0f, // beta
in_dev,
nullptr,
out_dev,
nullptr,
in_elementwise_op,
acc_elementwise_op),
// FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of
// float32 precision. Make it support any data type so the fields can be removed.
alpha_(alpha),
beta_(beta)
{ {
// std::cout << "blkGroupSize= " << this->blkGroupSize if(Rank != inLengths.size() || Rank != inStrides.size() ||
// << ", numBlockTileIteration= " << this->numBlockTileIteration NumReduceDim != reduceDims.size())
// << ", gridSize=" << this->gridSize {
// << ", invariant_total_length=" << this->invariant_total_length << throw std::runtime_error(
// std::endl; "One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
long_index_t invariant_total_length;
long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length_ = 1;
else
invariant_lowest_length_ = inLengths_[NumInvariantDim - 1];
blkGroupSize = 1;
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
} }
std::vector<index_t> inLengths_;
std::vector<index_t> inStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
InElementwiseOp in_elementwise_op_;
AccElementwiseOp acc_elementwise_op_;
index_t invariant_lowest_length_;
int blkGroupSize;
int numBlockTileIteration;
size_t gridSize;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
bool sweep_once = bool sweep_once =
...@@ -191,16 +282,45 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -191,16 +282,45 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
}; };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); if constexpr(InSrcVectorDim == 0)
{
if constexpr(kNumInvariantDim == 0)
{
return false;
}
else
{
if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
{
return false;
}
if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
{
return false;
}
}
}
else
{
if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
{
return false;
}
if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
{
return false;
}
}
if(!Reduction::IsSupportedArgument(p_arg_)) // To improve
if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
{ {
return false; return false;
} }
if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0) if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
{ {
return false; return false;
} }
...@@ -208,6 +328,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -208,6 +328,32 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
return true; return true;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const AccDataType alpha,
const AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
{
return Argument{inLengths,
inStrides,
reduceDims,
alpha,
beta,
in_dev,
out_dev,
in_elementwise_op,
acc_elementwise_op};
};
// //
// @brief Makes a pointer to Argument class. // @brief Makes a pointer to Argument class.
// //
...@@ -247,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -247,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
acc_elementwise_op); acc_elementwise_op);
}; };
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
...@@ -257,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -257,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ","; str << "DeviceReduceSoftmax<"
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; << Rank << "," << NumReduceDim << "," << BlockSize << ","
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
<< "InSrcVectorDim_" << InSrcVectorDim
<< "_InSrcVectorSize_" << InSrcVectorSize
<< "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
enum struct MaskingSpecialization
{
MaskDisabled,
MaskOutUpperTriangle
};
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
{
case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
default: return "Unrecognized specialization!";
}
}
struct MaskDisabledPredicate
{
__host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
{
return false;
};
__host__ __device__ constexpr bool
IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
{
return false;
}
};
struct MaskOutUpperTrianglePredicate
{
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
{
return operator()(m + m_tile - 1, n);
}
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
{
return predicate_(m, n) || IsNOutOfBound(n);
}
__host__ __device__ constexpr bool
IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
{
return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
}
private:
// index_t MRaw_;
index_t NRaw_;
MaskOutPredicate predicate_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForBlockwiseWelford
{
GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration,
long_index_t reduce_length)
: numBlockTileIteration_{numBlockTileIteration}
{
count_in_last_tile_ = reduce_length % K_BlockTileSize;
};
__device__ index_t operator()(index_t thread_k_cluster_id) const
{
if(count_in_last_tile_ == 0)
return (KThreadSliceSize * numBlockTileIteration_);
else
{
index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize;
index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIteration_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
else
return (KThreadSliceSize * (numBlockTileIteration_ - 1));
};
};
index_t numBlockTileIteration_;
index_t count_in_last_tile_;
};
template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForMultiblockWelford
{
GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize,
index_t numBlockTileIteration,
long_index_t reduce_length)
: blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
{
last_block_reduce_length_ =
reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
numBlockTileIterationByLastBlock_ =
(last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
};
__device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
{
if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ ||
block_local_id < blkGroupSize_ - 1)
return (KThreadSliceSize * numBlockTileIteration_);
index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize;
if(count_in_last_tile == 0)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else
{
index_t num_complete_slice = count_in_last_tile / KThreadSliceSize;
if(thread_k_cluster_id < num_complete_slice)
return (KThreadSliceSize * numBlockTileIterationByLastBlock_);
else if(thread_k_cluster_id == num_complete_slice)
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) +
count_in_last_tile);
else
return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1));
};
};
index_t blkGroupSize_;
index_t numBlockTileIteration_;
index_t last_block_reduce_length_;
index_t numBlockTileIterationByLastBlock_;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
#pragma once
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Activation_Mul_Clamp
{
Activation_Mul_Clamp(float multiplier, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
{
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
float y_fp32 = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
{
// We might type_convert to int8 after lambda in someplace
float x_fp32 = ck::type_convert<float>(x);
activationOp_(x_fp32, x_fp32);
y = math::clamp(multiplier_ * x_fp32, -128.f, 127.f);
}
float multiplier_;
Activation activationOp_;
};
// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
: multiplier_(multiplier), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
{
float y_fp32 = ck::type_convert<float>(x1 + x2);
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float multiplier_;
Activation activationOp_;
};
// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
: multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
{
}
__host__ __device__ constexpr void
operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
{
float y_fp32 = ck::type_convert<float>(x1 + x2);
y_fp32 = multiplier1_ * y_fp32;
activationOp_(y_fp32, y_fp32);
y_fp32 = math::clamp(multiplier2_ * y_fp32, -128.f, 127.f);
y = ck::type_convert<int8_t>(y_fp32);
}
float multiplier1_;
float multiplier2_;
Activation activationOp_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseMultiblockWelfordFirstHalf_,
typename XDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
const XGridDesc_M_K x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
mean_var_count_grid_desc_m_g,
get_reduce_count_per_thread,
num_k_block_tile_iteration,
p_x,
p_welford_mean,
p_welford_variance,
p_welford_count);
};
template <typename XDataType,
typename AccDataType,
typename MeanVarDataType,
typename XGridDesc_M_K,
typename MeanVarCountGridDesc_M_G,
typename GetReduceCountPerThreadFunctor,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcCountSrcVectorDim,
index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
(XSrcCountSrcVectorDim == 1 &&
KThreadSliceSize % XSrcCountSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
false>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
index_t num_k_block_tile_iteration,
const XDataType* const __restrict__ p_x,
MeanVarDataType* const p_welford_mean,
MeanVarDataType* const p_welford_variance,
int32_t* const p_welford_count)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcCountSrcVectorDim,
XSrcCountSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_welford_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
auto threadwise_welford_count_store =
ThreadwiseTensorSliceTransfer_v1r3<int32_t,
int32_t,
decltype(thread_buffer_desc_m_1),
MeanVarCountGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_count_grid_desc_m_g,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
block_local_id),
PassThroughOp{});
constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ =
get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
welford_count_thread_buf(I) = threadwise_welford.cur_count_;
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
if(thread_k_cluster_id == 0)
{
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_mean_thread_buf,
mean_var_count_grid_desc_m_g,
welford_mean_global_val_buf);
threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_var_thread_buf,
mean_var_count_grid_desc_m_g,
welford_var_global_val_buf);
threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
make_tuple(I0, I0),
welford_count_thread_buf,
mean_var_count_grid_desc_m_g,
welford_count_global_val_buf);
};
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseWelfordSecondHalfBatchNormForwardFinal_,
typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M>
__global__ void kernel_welford_second_half_batchnorm_forward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K y_grid_desc_m_k,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
const MeanVarGridDesc_M mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const XDataType* const __restrict__ p_x,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k,
y_grid_desc_m_k,
mean_var_count_grid_desc_m_k,
scale_grid_desc_m,
bias_grid_desc_m,
mean_var_grid_desc_m,
blkgroup_size,
num_xy_k_block_tile_iteration,
num_mean_var_count_k_block_tile_iteration,
epsilon,
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
p_x,
p_scale,
p_bias,
y_elementwise_op,
p_y,
updateMovingAverage,
averageFactor,
resultRunningMean,
resultRunningVariance,
saveMeanInvVariance,
resultSaveMean,
resultSaveInvVariance);
};
template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename YElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t XSrcYDstVectorDim,
index_t XSrcVectorSize,
index_t YDstVectorSize,
index_t ScaleSrcVectorSize,
index_t BiasSrcVectorSize,
index_t MeanVarSrcDstVectorSize>
struct GridwiseWelfordSecondHalfBatchNormForwardFinal
{
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_1 = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using ThreadwiseWelford =
ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
using BlockwiseWelford = BlockwiseWelford<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& y_grid_desc_m_k,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
AccDataType epsilon,
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const XDataType* const __restrict__ p_x,
const ScaleDataType* const __restrict__ p_scale,
const BiasDataType* const __restrict__ p_bias,
const YElementwiseOp y_elementwise_op,
YDataType* const __restrict__ p_y,
bool updateMovingAverage,
AccDataType averageFactor,
MeanVarDataType* const __restrict__ resultRunningMean,
MeanVarDataType* const __restrict__ resultRunningVariance,
bool saveMeanInvVariance,
MeanVarDataType* const __restrict__ resultSaveMean,
MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
using ck::math::sqrt;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
in_welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize * 1, true>
in_welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
welford_var_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
welford_count_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
x_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
y_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> bias_thread_buf;
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / blkgroup_size;
const index_t block_local_id = block_global_id % blkgroup_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
auto threadwise_mean_var_load_m_k =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
auto threadwise_count_load_m_k =
ThreadwiseTensorSliceTransfer_v2<int32_t,
int32_t,
MeanVarCountGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
1,
1,
1,
true>(
mean_var_count_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * 1));
const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
constexpr auto mean_var_count_thread_copy_step_m_k =
make_multi_index(0, KThreadClusterSize * 1);
// Step 1: do final welford reduction to get mean and variance
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
welford_count_thread_buf(I) = 0;
});
for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
++reducedTiles)
{
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_mean_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_mean_thread_buf);
threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_var_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_var_thread_buf);
threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
welford_count_global_val_buf,
thread_buffer_desc_m_1,
make_tuple(I0, I0),
in_welford_count_thread_buf);
ThreadwiseWelford::Run(in_welford_mean_thread_buf,
in_welford_var_thread_buf,
in_welford_count_thread_buf,
welford_mean_thread_buf,
welford_var_thread_buf,
welford_count_thread_buf);
threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
mean_var_count_thread_copy_step_m_k);
}
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if constexpr(I > 0)
block_sync_lds();
BlockwiseWelford::Run(
welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
});
// Step 2: do normalization and output y
const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
AccDataType,
XYGridDesc_M_K,
decltype(thread_buffer_desc_m_k),
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
XSrcVectorSize,
1,
true>(
x_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id +
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_y_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
YDataType,
decltype(thread_buffer_desc_m_k),
XYGridDesc_M_K,
YElementwiseOp,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcYDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
y_grid_desc_m_k,
make_multi_index(
blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
y_elementwise_op);
auto threadwise_scale_load =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
ScaleSrcVectorSize,
1,
true>(
scale_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
BiasSrcVectorSize,
1,
true>(
bias_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_x, x_grid_desc_m_k.GetElementSpaceSize());
const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_scale, scale_grid_desc_m.GetElementSpaceSize());
const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias, bias_grid_desc_m.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y, y_grid_desc_m_k.GetElementSpaceSize());
threadwise_scale_load.Run(scale_grid_desc_m,
scale_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
scale_thread_buf);
threadwise_bias_load.Run(bias_grid_desc_m,
bias_global_val_buf,
thread_buffer_desc_m,
make_tuple(I0),
bias_thread_buf);
constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles)
{
threadwise_x_load.Run(x_grid_desc_m_k,
x_global_val_buf,
thread_buffer_desc_m_k,
make_tuple(I0, I0),
x_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType multiplier =
scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon);
AccDataType fused_mean_bias =
bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
y_thread_buf(Number<offset>{}) =
x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
});
});
threadwise_y_store.Run(thread_buffer_desc_m_k,
make_tuple(I0, I0),
y_thread_buf,
y_grid_desc_m_k,
y_global_val_buf);
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k);
}
// Step 3: update the moving average of mean and variance (optional)
if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
{
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_mean_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
running_var_thread_buf;
auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
auto threadwise_mean_var_load_m =
ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
AccDataType,
MeanVarGridDesc_M,
decltype(thread_buffer_desc_m),
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
running_mean_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf);
threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
running_var_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf);
AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
welford_mean_thread_buf[I] * averageFactor;
running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
welford_var_thread_buf[I] * averageFactor;
});
auto threadwise_mean_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_mean_thread_buf,
mean_var_grid_desc_m,
running_mean_global_buf);
threadwise_mean_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
running_var_thread_buf,
mean_var_grid_desc_m,
running_var_global_buf);
};
// Step 4: save mean and inv-variance (optional)
if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
{
auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
welford_var_thread_buf(I) =
type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
});
auto threadwise_mean_inv_var_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
MeanVarDataType,
decltype(thread_buffer_desc_m),
MeanVarGridDesc_M,
PassThroughOp,
ThreadBufferLengths_M,
Sequence<0>,
0,
MeanVarSrcDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
mean_var_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
welford_mean_thread_buf,
mean_var_grid_desc_m,
result_mean_global_buf);
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
make_tuple(I0),
welford_var_thread_buf,
mean_var_grid_desc_m,
result_inv_var_global_buf);
};
}
};
} // namespace ck
...@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
index_t M01 = 1, index_t M01 = 1,
index_t N01 = 1, index_t N01 = 1,
index_t KSplit = 1) index_t KSplit = 1)
: M01_(M01), : c_grid_desc_m_n_(c_grid_desc_m_n),
M01_(M01),
N01_(N01), N01_(N01),
KSplit_(KSplit), KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{ {
return underlying_map_.CalculateBottomIndex(idx_top); static_assert(TopIdx::Size() == 1);
return underlying_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] % CalculateGridSize()));
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
} }
private: private:
__device__ constexpr index_t CalculateGridSize() const
{
return CalculateGridSize(c_grid_desc_m_n_);
}
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01, index_t M01,
index_t N01, index_t N01,
...@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
} }
CGridDesc_M_N c_grid_desc_m_n_;
index_t M01_, N01_, KSplit_; index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_; UnderlyingMap underlying_map_;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -74,7 +74,8 @@ template <typename FloatAB, ...@@ -74,7 +74,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmGemm_Xdl_CShuffle struct GridwiseBatchedGemmGemm_Xdl_CShuffle
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
...@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// gridwise GEMM pipeline // gridwise GEMM pipeline
// Only supports LoopScheduler::Default // Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>(); NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment