Commit ce72f286 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 50320413 f30e5975
@@ -20,7 +20,8 @@ template <typename ALayout,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          typename ComputeType = CDataType>
 struct DeviceGemmSplitK : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
@@ -48,7 +49,8 @@ template <typename ALayout,
           typename CDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CElementwiseOperation,
+          typename ComputeType = CDataType>
 using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
                                                              BLayout,
                                                              CLayout,
@@ -57,7 +59,8 @@ using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
                                                              CDataType,
                                                              AElementwiseOperation,
                                                              BElementwiseOperation,
-                                                             CElementwiseOperation>>;
+                                                             CElementwiseOperation,
+                                                             ComputeType>>;
 } // namespace device
 } // namespace tensor_operation
...
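Editor's note: a minimal standalone sketch (not CK code) of why appending a trailing `ComputeType = CDataType` parameter is backward compatible — existing instantiations keep compiling with the default, while callers can now opt into a different compute precision.

```cpp
#include <type_traits>

// Hypothetical interface, for illustration only.
template <typename CDataType, typename ComputeType = CDataType>
struct DeviceOp
{
    using Compute = ComputeType;
};

// Old-style instantiation: ComputeType silently defaults to CDataType.
static_assert(std::is_same_v<DeviceOp<float>::Compute, float>);
// New-style instantiation: an explicit compute type may differ from the output type.
static_assert(std::is_same_v<DeviceOp<float, double>::Compute, double>);

int main() { return 0; }
```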
@@ -14,8 +14,8 @@ namespace device {
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
@@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator
                                const std::vector<index_t> gammaStrides,
                                const std::vector<index_t> betaStrides,
                                const std::vector<index_t> yStrides,
+                               const std::vector<index_t> saveMeanStrides,
+                               const std::vector<index_t> saveInvStdStrides,
                                const std::vector<index_t> reduceDims,
                                double epsilon,
                                const void* p_x,
@@ -43,16 +45,16 @@ template <typename XDataType,
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
-          typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim>
 using DeviceNormalizationPtr = std::unique_ptr<DeviceNormalization<XDataType,
                                                                    GammaDataType,
                                                                    BetaDataType,
-                                                                   ComputeDataType,
                                                                    YDataType,
+                                                                   SaveMeanInvStdDataType,
                                                                    YElementwiseOperation,
                                                                    Rank,
                                                                    NumReduceDim>>;
...
@@ -296,6 +296,28 @@ struct DeviceElementwiseImpl
     {
         return std::make_unique<Invoker>();
     };
 
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        // clang-format off
+        str << "DeviceElementwiseImpl<";
+        str << "NumDim_" << NumDim << ",";
+        str << "MPerThread_" << MPerThread << ",";
+        str << "InScalarPerVector";
+        static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; });
+        str << ",";
+        str << "OutScalarPerVector";
+        static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; });
+        str << ">";
+        // clang-format on
+
+        return str.str();
+    }
 }; // namespace device
 } // namespace device
...
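Editor's note: a hypothetical illustration (not CK code) of the kind of identifier the new GetTypeString() assembles, rebuilt with a plain std::stringstream and made-up values.

```cpp
#include <iostream>
#include <sstream>

int main()
{
    // Mimics GetTypeString() above with hypothetical values:
    // NumDim = 4, MPerThread = 8, InScalarPerVectorSeq = {8, 8}, OutScalarPerVectorSeq = {8}.
    std::stringstream str;
    str << "DeviceElementwiseImpl<"
        << "NumDim_" << 4 << ","
        << "MPerThread_" << 8 << ","
        << "InScalarPerVector_8_8" << ","
        << "OutScalarPerVector_8" << ">";

    // Prints: DeviceElementwiseImpl<NumDim_4,MPerThread_8,InScalarPerVector_8_8,OutScalarPerVector_8>
    std::cout << str.str() << '\n';
    return 0;
}
```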
@@ -69,7 +69,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              CDataType,
                                                              AElementwiseOperation,
                                                              BElementwiseOperation,
-                                                             CElementwiseOperation>
+                                                             CElementwiseOperation,
+                                                             ComputeType>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -126,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              PipelineVer,
                                                              ComputeType>;
 
-    using Argument = typename GridwiseGemm::Argument;
+    struct Argument : public GridwiseGemm::Argument
+    {
+        Argument(const ADataType* p_a_grid_,
+                 const BDataType* p_b_grid_,
+                 CDataType* p_c_grid_,
+                 index_t M_,
+                 index_t N_,
+                 index_t K_,
+                 index_t StrideA_,
+                 index_t StrideB_,
+                 index_t StrideC_,
+                 index_t MPadded_,
+                 index_t NPadded_,
+                 index_t KPadded_,
+                 index_t K0_,
+                 index_t k_batch_,
+                 AElementwiseOperation a_element_op_,
+                 BElementwiseOperation b_element_op_,
+                 CElementwiseOperation c_element_op_)
+            : GridwiseGemm::Argument(p_a_grid_,
+                                     p_b_grid_,
+                                     p_c_grid_,
+                                     M_,
+                                     N_,
+                                     K_,
+                                     StrideA_,
+                                     StrideB_,
+                                     StrideC_,
+                                     MPadded_,
+                                     NPadded_,
+                                     KPadded_,
+                                     K0_,
+                                     k_batch_),
+              a_element_op(a_element_op_),
+              b_element_op(b_element_op_),
+              c_element_op(c_element_op_)
+        {
+        }
+
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CElementwiseOperation c_element_op;
+    };
+
     using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
 
     // Invoker
@@ -167,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                               karg.M * karg.N * sizeof(CDataType),
                                               stream_config.stream_id_));
 
-                ave_time = launch_and_time_kernel(
-                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(gdx, gdy, gdz),
+                                           dim3(BlockSize),
+                                           0,
+                                           static_cast<typename GridwiseGemm::Argument>(karg),
+                                           b2c_map,
+                                           karg.a_element_op,
+                                           karg.b_element_op,
+                                           karg.c_element_op);
             };
 
             if(has_main_k0_block_loop)
@@ -179,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                              true,
                                                              InMemoryDataOperationEnum::Set,
-                                                             DefaultBlock2CTileMap>;
+                                                             DefaultBlock2CTileMap,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>;
                     Run(kernel);
                 }
@@ -189,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                              true,
                                                              InMemoryDataOperationEnum::AtomicAdd,
-                                                             DefaultBlock2CTileMap>;
+                                                             DefaultBlock2CTileMap,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>;
                     Run(kernel);
                 }
@@ -202,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                              false,
                                                              InMemoryDataOperationEnum::Set,
-                                                             DefaultBlock2CTileMap>;
+                                                             DefaultBlock2CTileMap,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>;
                     Run(kernel);
                 }
@@ -212,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                              false,
                                                              InMemoryDataOperationEnum::AtomicAdd,
-                                                             DefaultBlock2CTileMap>;
+                                                             DefaultBlock2CTileMap,
+                                                             AElementwiseOperation,
+                                                             BElementwiseOperation,
+                                                             CElementwiseOperation>;
                     Run(kernel);
                 }
@@ -260,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
-                             AElementwiseOperation,
-                             BElementwiseOperation,
-                             CElementwiseOperation,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation c_element_op,
                              index_t KBatch)
    {
-        return Argument{p_a,
+        return Argument(p_a,
                         p_b,
                         p_c,
                         M,
@@ -278,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         GridwiseGemm::CalculateNPadded(N),
                         GridwiseGemm::CalculateKPadded(K, KBatch),
                         GridwiseGemm::CalculateK0(K, KBatch),
-                        KBatch};
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        c_element_op);
    }
 
    static auto MakeInvoker() { return Invoker{}; }
@@ -293,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                        index_t StrideA,
                        index_t StrideB,
                        index_t StrideC,
-                       AElementwiseOperation,
-                       BElementwiseOperation,
-                       CElementwiseOperation,
+                       AElementwiseOperation a_element_op,
+                       BElementwiseOperation b_element_op,
+                       CElementwiseOperation c_element_op,
                        ck::index_t KBatch = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
@@ -311,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                          GridwiseGemm::CalculateNPadded(N),
                                          GridwiseGemm::CalculateKPadded(K, KBatch),
                                          GridwiseGemm::CalculateK0(K, KBatch),
-                                         KBatch);
+                                         KBatch,
+                                         a_element_op,
+                                         b_element_op,
+                                         c_element_op);
    }
 
    // polymorphic
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop) {
             constexpr bool has_main_loop = has_main_k_block_loop.value;
 
-            const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+            const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                 GridwiseGemm,
                 ADataType,
                 BDataType,
...
@@ -12,6 +12,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
@@ -22,32 +23,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-namespace {
-
-struct ComputePtrOffsetOfStridedBatch
-{
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideC_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    index_t BatchStrideC_;
-};
-
-} // namespace
-
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
         Block2CTileMap block_2_ctile_map_;
 
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
 
         // element-wise op
         OutElementwiseOperation a_element_op_;
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
                 remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
                 remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
                 remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch<I0>,
                 has_main_loop,
                 has_double_loop>;
...
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -21,32 +22,6 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-namespace {
-
-struct ComputePtrOffsetOfStridedBatch
-{
-    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideA_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideB_);
-    }
-
-    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
-    {
-        return g_idx * static_cast<long_index_t>(BatchStrideC_);
-    }
-
-    index_t BatchStrideA_;
-    index_t BatchStrideB_;
-    index_t BatchStrideC_;
-};
-
-} // namespace
-
 template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         Block2CTileMap block_2_ctile_map_;
 
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+        ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
 
         index_t M01_;
         index_t N01_;
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                 remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                 remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
                 remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch,
+                ComputePtrOffsetOfStridedBatch<I0>,
                 has_main_loop>;
 
             return launch_and_time_kernel(stream_config,
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
         if constexpr(NDimSpatial == 1)
         {
             if constexpr(!is_GNWK_GKXC_GNWC)
...
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
         auto launch_kernel = [&](auto has_main_k_block_loop) {
             constexpr bool has_main_loop = has_main_k_block_loop.value;
 
-            const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
+            const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
                 GridwiseOp,
                 ADataType,
                 BDataType,
...
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
         return ds_offset;
     }
 
-    __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
+    {
+        return g_idx * static_cast<long_index_t>(BatchStrideE_);
+    }
+
+    // alias for kernels without multiple D
+    [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
     {
         return g_idx * static_cast<long_index_t>(BatchStrideE_);
     }
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
     index_t BatchStrideB_;
     Array<ck::index_t, NumDTensor> BatchStrideDs_;
     index_t BatchStrideE_;
+    index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
 };
 
 } // namespace device
...
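Editor's note: a minimal standalone sketch (not CK code) of the reference-member aliasing pattern used above, where the legacy name `BatchStrideC_` is kept bound to the renamed `BatchStrideE_` field so that kernels without multiple D tensors can reuse the same offset struct.

```cpp
#include <cassert>

// Hypothetical struct, for illustration only.
struct Offsets
{
    int BatchStrideE_  = 0;
    int& BatchStrideC_ = BatchStrideE_; // legacy name, refers to the same storage
};

int main()
{
    Offsets o;
    o.BatchStrideC_ = 42;          // write through the old name
    assert(o.BatchStrideE_ == 42); // observed through the new name
    return 0;
}
```

A reference member does implicitly delete copy assignment and makes the compiler-generated copy constructor re-bind the new reference to the source object's field, which is worth keeping in mind whenever such a struct is copied by value.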
@@ -28,6 +28,7 @@ template <typename XDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
+          typename SaveMeanInvStdDataType,
           typename YElementwiseOperation,
           index_t Rank,
           index_t NumReduceDim,
@@ -43,12 +44,13 @@ template <typename XDataType,
           index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize,
+          index_t SaveMeanInvStdDstVectorSize,
           bool UseWelford = true>
 struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
-                                                            ComputeDataType,
                                                             YDataType,
+                                                            SaveMeanInvStdDataType,
                                                             YElementwiseOperation,
                                                             Rank,
                                                             NumReduceDim>
@@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                       (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
                   "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
 
+    static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+                  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+                  "configuration, please check!");
+
     using PassThrough = tensor_operation::element_wise::PassThrough;
 
+    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
+    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+    static_assert(!reduceAllDim); // TODO
+
     static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                     const std::vector<index_t>& inStrides,
                                     int numBlockTileIteration)
     {
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
         static constexpr index_t numSrcDim = Rank;
-        static constexpr bool reduceAllDim = (NumInvariantDim == 0);
 
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
@@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         return (in_grid_desc_m_k_padded);
     };
 
+    static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
+                                               const std::vector<index_t>& strides)
+    {
+        using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+
+        const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
+        const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
+
+        const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto grid_desc_m =
+            transform_tensor_descriptor(desc,
+                                        make_tuple(make_merge_transform(tupleSrcLengths)),
+                                        make_tuple(InvariantDims{}),
+                                        make_tuple(Sequence<0>{}));
+
+        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
+
+        const auto pad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+
+        auto grid_desc_m_padded = transform_tensor_descriptor(
+            grid_desc_m,
+            make_tuple(make_right_pad_transform(invariantLength, pad_M)),
+            make_tuple(Sequence<0>{}),
+            make_tuple(Sequence<0>{}));
+
+        return grid_desc_m_padded;
+    }
+
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
+    using GridDesc_M   = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
 
     struct Argument : public BaseArgument
     {
@@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
                  const std::vector<index_t> yStrides,
+                 const std::vector<index_t> saveMeanStrides,
+                 const std::vector<index_t> saveInvStdStrides,
                  const std::vector<index_t> reduceDims,
                  YElementwiseOperation y_elementwise_op,
                  double epsilon,
                  const XDataType* p_x,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
-                 YDataType* p_y)
+                 YDataType* p_y,
+                 SaveMeanInvStdDataType* p_saveMean,
+                 SaveMeanInvStdDataType* p_saveInvStd)
             : p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
+              p_saveMean_(p_saveMean),
+              p_saveInvStd_(p_saveInvStd),
               y_elementwise_op_(y_elementwise_op)
         {
             epsilon_ = static_cast<ComputeDataType>(epsilon);
@@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             yStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
             gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
             betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
+            saveMeanStrides_   = saveMeanStrides;
+            saveInvStdStrides_ = saveInvStdStrides;
 
-            long_index_t invariant_length;
-            long_index_t reduce_length;
-
-            std::tie(invariant_length, reduce_length) =
-                get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
 
-            numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
-            gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
+            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
+            gridSize_              = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
 
             x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
             gamma_grid_desc_m_k_ =
@@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             beta_grid_desc_m_k_ =
                 MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
             y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
+            save_mean_grid_desc_m_    = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides);
+            save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides);
 
             isSweeponce_ =
                 x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
+
+            if constexpr(NumInvariantDim == 0)
+                invariant_lowest_length_ = 1;
+            else
+                invariant_lowest_length_ = Lengths_[NumInvariantDim - 1];
         }
 
         ComputeDataType epsilon_;
@@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
         YDataType* p_y_;
+        SaveMeanInvStdDataType* p_saveMean_;
+        SaveMeanInvStdDataType* p_saveInvStd_;
 
         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
         std::vector<index_t> yStrides_;
+        std::vector<index_t> saveMeanStrides_;
+        std::vector<index_t> saveInvStdStrides_;
 
         YElementwiseOperation y_elementwise_op_;
@@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
         GridDesc_M_K gamma_grid_desc_m_k_;
         GridDesc_M_K beta_grid_desc_m_k_;
         GridDesc_M_K y_grid_desc_m_k_;
+        GridDesc_M save_mean_grid_desc_m_;
+        GridDesc_M save_inv_std_grid_desc_m_;
 
         bool isSweeponce_;
+
+        index_t MRaw_; // invarient length
+        index_t KRaw_; // reduce length
+
+        index_t invariant_lowest_length_;
     };
 
     struct Invoker : public BaseInvoker
@@ -217,9 +277,11 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                  GammaDataType,
                                  BetaDataType,
                                  YDataType,
+                                 SaveMeanInvStdDataType,
                                  ComputeDataType,
                                  YElementwiseOperation,
                                  GridDesc_M_K,
+                                 GridDesc_M,
                                  BlockSize,
                                  MThreadClusterSize,
                                  KThreadClusterSize,
@@ -233,6 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                  BetaSrcVectorSize,
                                  XYSrcVectorDim,
                                  YDstVectorSize,
+                                 SaveMeanInvStdDstVectorSize,
                                  UseWelford>(arg.isSweeponce_);
 
             float avg_time = 0;
@@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                                               arg.gamma_grid_desc_m_k_,
                                               arg.beta_grid_desc_m_k_,
                                               arg.y_grid_desc_m_k_,
+                                              arg.save_mean_grid_desc_m_,
+                                              arg.save_inv_std_grid_desc_m_,
                                               arg.numBlockTileIteration_,
                                               arg.epsilon_,
                                               arg.p_x_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               arg.p_y_,
+                                              arg.p_saveMean_,
+                                              arg.p_saveInvStd_,
                                               arg.y_elementwise_op_);
 
             return (avg_time);
@@ -267,8 +334,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
 
-        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
-
         if constexpr(XYSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)
@@ -277,13 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             }
             else
             {
+                printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
+
                 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
                     return false;
 
-                if(p_arg_->invariant_lowest_length % YDstVectorSize != 0)
+                if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
                     return false;
             };
         }
@@ -325,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
             if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                 return (false);
 
-            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+            if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
                 return (false);
         }
         else // if fastest dim is reduced
@@ -337,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 return (false);
         }
 
+        if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
+            return false;
+
         return true;
     };
@@ -346,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                          const std::vector<index_t> gammaStrides,
                          const std::vector<index_t> betaStrides,
                          const std::vector<index_t> yStrides,
+                         const std::vector<index_t> saveMeanStrides,
+                         const std::vector<index_t> saveInvStdStrides,
                          const std::vector<index_t> reduceDims,
                          double epsilon,
                          const void* p_x,
@@ -353,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                          const void* p_beta,
                          void* p_y,
                          void* p_saveMean,
-                         void* p_saveInvVar,
+                         void* p_saveInvStd,
                          YElementwiseOperation y_elementwise_op) override
     {
-        // TODO
-        // Optional cache of the intermediate results (mean and InvVariance) during the
-        // forward pass could speedup in the backward
-        ignore = p_saveMean;
-        ignore = p_saveInvVar;
+        if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
+           betaStrides.size() != Rank || yStrides.size() != Rank ||
+           saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
+            throw std::runtime_error("dimension is incorrect");
 
         return std::make_unique<Argument>(lengths,
                                           xStrides,
                                           gammaStrides,
                                           betaStrides,
                                           yStrides,
+                                          saveMeanStrides,
+                                          saveInvStdStrides,
                                           reduceDims,
                                           y_elementwise_op,
                                           epsilon,
                                           static_cast<const XDataType*>(p_x),
                                           static_cast<const GammaDataType*>(p_gamma),
                                           static_cast<const BetaDataType*>(p_beta),
-                                          static_cast<YDataType*>(p_y));
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveMean),
+                                          static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
     };
 
     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...
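Editor's note: a standalone host-side sketch (not the CK kernel, hypothetical helper name) of what the new saveMean/saveInvStd outputs hold per invariant row: the mean and 1/sqrt(var + eps) computed during the forward normalization, kept as a side output so a backward pass can reuse them.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Normalize one invariant row and emit the per-row statistics,
// mirroring the semantics of p_saveMean / p_saveInvStd above.
void normalize_row(const std::vector<float>& x, float eps,
                   std::vector<float>& y, float& save_mean, float& save_inv_std)
{
    float mean = 0.f;
    for(float v : x)
        mean += v;
    mean /= static_cast<float>(x.size());

    float var = 0.f;
    for(float v : x)
        var += (v - mean) * (v - mean);
    var /= static_cast<float>(x.size());

    save_mean    = mean;
    save_inv_std = 1.f / std::sqrt(var + eps);

    y.resize(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        y[i] = (x[i] - mean) * save_inv_std;
}
```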
@@ -113,7 +113,6 @@ struct PassThrough
     }
 #endif
 
-#if defined CK_ENABLE_FP8
     template <>
     __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
     {
@@ -143,9 +142,7 @@ struct PassThrough
     {
         y = type_convert<f8_t>(x);
     }
-#endif
 
-#if defined CK_ENABLE_BF8
     template <>
     __host__ __device__ void operator()<bf8_t, bf8_t>(bf8_t& y, const bf8_t& x) const
     {
@@ -175,7 +172,6 @@ struct PassThrough
     {
         y = ck::type_convert<bf8_t>(x);
     }
-#endif
 };
 
 struct UnaryConvert
@@ -204,7 +200,6 @@ struct ConvertBF16RTN
     }
 };
 
-#if defined CK_ENABLE_FP8
 struct ConvertF8SR
 {
     // convert to fp8 using stochastic rounding (SR)
@@ -212,7 +207,8 @@ struct ConvertF8SR
     __host__ __device__ void operator()(Y& y, const X& x) const
     {
         // check Y datatype
-        static_assert(is_same<Y, f8_t>::value, "Data type is not supported by this operation!");
+        static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
+                      "Data type is not supported by this operation!");
 
         // check X datatype
         static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
@@ -221,7 +217,6 @@ struct ConvertF8SR
         y = f8_convert_sr<Y>(x);
     }
 };
-#endif
 
 struct Scale
 {
@@ -448,10 +443,11 @@ struct Sigmoid
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value,
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
-        y = 1 / (ck::type_convert<T>(1) + exp(-x));
+        constexpr T one = type_convert<T>(1);
+        y               = one / (one + ck::math::exp(-x));
     };
 };
@@ -461,7 +457,8 @@ struct TanH
     __host__ __device__ void operator()(T& y, const T& x) const
     {
         static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value,
+                          is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
+                          is_same<T, int32_t>::value,
                       "Data type is not supported by this operation!");
 
         y = ck::math::tanh(x);
@@ -487,7 +484,101 @@ struct Swish
         y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
     };
 
-    float beta_ = 1.0f;
+    const float beta_;
+};
+
+struct SoftRelu
+{
+    SoftRelu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y               = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
+    }
+    const float alpha_;
+};
+
+struct Power
+{
+    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
+        : alpha_(alpha), beta_(beta), gamma_(gamma){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha     = type_convert<T>(alpha_);
+        T casted_beta      = type_convert<T>(beta_);
+        T casted_gamma     = type_convert<T>(gamma_);
+        T shifted_scaled_x = casted_alpha + casted_beta * x;
+        y                  = ck::math::pow(shifted_scaled_x, casted_gamma);
+    }
+    const float alpha_;
+    const float beta_;
+    const float gamma_;
+};
+
+struct ClippedRelu
+{
+    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        T casted_beta  = type_convert<T>(beta_);
+        y              = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
+    }
+    const float alpha_;
+    const float beta_;
+};
+
+struct LeakyRelu
+{
+    LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x >= 0 ? x : x * casted_alpha;
+    }
+    const float alpha_;
+};
+
+struct Elu
+{
+    Elu(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha = type_convert<T>(alpha_);
+        y              = x > 0 ? x : casted_alpha * ck::math::expm1(x);
+    }
+    const float alpha_;
 };
 
 } // namespace element_wise
...
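Editor's note: the functors above are normally plugged into CK device operations as the element-wise op template parameter; the following standalone host-side sketch (simplified, float-only, not the CK definitions) only illustrates the `op(y, x)` call convention they all share.

```cpp
#include <cmath>
#include <cstdio>

// Simplified host-only stand-ins for two of the functors added above (illustration only).
struct LeakyRelu
{
    explicit LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}
    void operator()(float& y, const float& x) const { y = x >= 0.f ? x : x * alpha_; }
    const float alpha_;
};

struct Elu
{
    explicit Elu(float alpha = 1.f) : alpha_(alpha) {}
    void operator()(float& y, const float& x) const { y = x > 0.f ? x : alpha_ * std::expm1(x); }
    const float alpha_;
};

int main()
{
    LeakyRelu leaky{0.1f};
    Elu elu{1.0f};

    float y1 = 0.f, y2 = 0.f;
    leaky(y1, -2.0f); // y1 == -0.2f
    elu(y2, -2.0f);   // y2 == expm1(-2) ~= -0.8647f

    std::printf("leaky_relu(-2) = %f, elu(-2) = %f\n", y1, y2);
    return 0;
}
```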
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
         }
     }();
 
-    if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+    if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                 GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
     {
         return transform_tensor_descriptor(c_grid_desc_m_n,
                                            make_tuple(make_right_pad_transform(M, MPad - M),
...