Commit ab663329 authored by aska-0096's avatar aska-0096
Browse files

Merge develop

parents 4fec5ad3 8a4253ba
...@@ -150,7 +150,10 @@ template <typename ADataType, ...@@ -150,7 +150,10 @@ template <typename ADataType,
ck::index_t BBlockTransferDstScalarPerVector_K1, ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumGemmKPrefetchStage = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -323,7 +326,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -323,7 +326,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
...@@ -622,6 +628,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -622,6 +628,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceBatchedGemmXdl" str << "DeviceBatchedGemmXdl"
<< "<" << "<"
...@@ -629,7 +641,13 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -629,7 +641,13 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock
<< ">"; << ">"
<< " NumGemmKPrefetchStage: "
<< NumGemmKPrefetchStage << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout, ...@@ -214,6 +214,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
K1,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
......
...@@ -141,7 +141,8 @@ template <typename ALayout, ...@@ -141,7 +141,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
...@@ -282,7 +283,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -282,7 +283,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
PipelineVer>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
...@@ -664,6 +666,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -664,6 +666,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle" str << "DeviceGemmMultipleD_Xdl_CShuffle"
<< "<" << "<"
...@@ -674,7 +682,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -674,7 +682,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">"; << ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -56,7 +56,9 @@ template <typename ADataType, ...@@ -56,7 +56,9 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1> ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmXdl : public DeviceGemm<ALayout, struct DeviceGemmXdl : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -230,7 +232,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -230,7 +232,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
NumPrefetch>; NumPrefetch,
LoopSched,
PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -523,6 +527,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -523,6 +527,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemmXdl" str << "DeviceGemmXdl"
<< "<" << "<"
...@@ -535,7 +545,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -535,7 +545,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << NXdlPerWave
<< ">"; << ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -64,7 +64,8 @@ template <typename ALayout, ...@@ -64,7 +64,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -393,7 +394,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -393,7 +394,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -656,6 +658,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -656,6 +658,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemm_Xdl_CShuffle" str << "DeviceGemm_Xdl_CShuffle"
<< "<" << "<"
...@@ -665,7 +673,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -665,7 +673,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1
<< ">"; << ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -24,17 +24,17 @@ template <typename GridwiseReduction, ...@@ -24,17 +24,17 @@ template <typename GridwiseReduction,
typename AccDataType, typename AccDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
typename GridDesc_M_K> typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, __global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
...@@ -54,7 +54,7 @@ namespace ck { ...@@ -54,7 +54,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Y = LayerNorm(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -168,49 +168,49 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -168,49 +168,49 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric = using GridwiseReduceLayernormGeneric =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
false>; false>;
using GridwiseReduceLayernormSweepOnce = using GridwiseNormalizationSweepOnce =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
true>; true>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -295,22 +295,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -295,22 +295,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel_main = arg.isSweeponce_ const auto kernel_main = arg.isSweeponce_
? kernel_layernorm<GridwiseReduceLayernormSweepOnce, ? kernel_normalization<GridwiseNormalizationSweepOnce,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K> GridDesc_M_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric, : kernel_normalization<GridwiseReduceLayernormGeneric,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean,
void* p_saveInvVar,
AccElementwiseOperation acc_elementwise_op) override AccElementwiseOperation acc_elementwise_op) override
{ {
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore = p_saveMean;
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
......
...@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock ...@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
throw std::runtime_error(
"One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
......
...@@ -40,8 +40,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -40,8 +40,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
AccElementwiseOp, AccElementwiseOp,
Rank> Rank>
{ {
static constexpr index_t kRank = Rank; static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim; static constexpr index_t kNumReduceDim = NumReduceDim;
static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
virtual index_t GetRank() const override { return kRank; } virtual index_t GetRank() const override { return kRank; }
...@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
throw std::runtime_error(
"One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
}; };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(kNumInvariantDim == 0)
{ {
return false; return false;
} }
else else
{ {
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1) if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
{
return false; return false;
}
if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0) if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
{
return false; return false;
}; }
}
} }
else else
{ {
if(p_arg_->inStrides_[Rank - 1] != 1) if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
{
return false; return false;
}
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0) if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
{
return false; return false;
}; }
}
// To improve
if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
{
return false;
}
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0) if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false; return false;
}
return true; return true;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const AccDataType alpha,
const AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
{
return Argument{inLengths,
inStrides,
reduceDims,
alpha,
beta,
in_dev,
out_dev,
in_elementwise_op,
acc_elementwise_op};
};
// //
// @brief Makes a pointer to Argument class. // @brief Makes a pointer to Argument class.
// //
...@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
acc_elementwise_op); acc_elementwise_op);
}; };
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
...@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ","; str << "DeviceReduceSoftmax<"
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; << Rank << "," << NumReduceDim << "," << BlockSize << ","
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
<< "InSrcVectorDim_" << InSrcVectorDim
<< "_InSrcVectorSize_" << InSrcVectorSize
<< "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment