"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "7e0566905aad8c90545fe26937af634577dbf395"
Commit f74b77bc authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents b5be51ed 0d911822
...@@ -67,7 +67,7 @@ __global__ void ...@@ -67,7 +67,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
...@@ -55,7 +55,7 @@ __global__ void ...@@ -55,7 +55,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940)) defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
...@@ -25,7 +25,7 @@ __global__ void ...@@ -25,7 +25,7 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
...@@ -46,7 +46,7 @@ __global__ void ...@@ -46,7 +46,7 @@ __global__ void
typename GridwiseGemm::Problem problem) typename GridwiseGemm::Problem problem)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
......
...@@ -58,7 +58,7 @@ __global__ void ...@@ -58,7 +58,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// TODO ANT: separate into MMA + Epilogue // TODO ANT: separate into MMA + Epilogue
......
...@@ -166,7 +166,7 @@ __global__ void ...@@ -166,7 +166,7 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
......
...@@ -45,7 +45,7 @@ __global__ void ...@@ -45,7 +45,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
......
...@@ -36,7 +36,7 @@ __global__ void ...@@ -36,7 +36,7 @@ __global__ void
const CGridDesc_M_N c_grid_desc_m_n) const CGridDesc_M_N c_grid_desc_m_n)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const auto a_grid_desc_k0_m_k1 = const auto a_grid_desc_k0_m_k1 =
......
...@@ -43,7 +43,7 @@ __global__ void ...@@ -43,7 +43,7 @@ __global__ void
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......
...@@ -31,7 +31,7 @@ __global__ void ...@@ -31,7 +31,7 @@ __global__ void
const Block2CTileMap& b2c_map) const Block2CTileMap& b2c_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
......
...@@ -47,7 +47,7 @@ __global__ void ...@@ -47,7 +47,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainK0BlockLoop>( GridwiseGemm::template Run<HasMainK0BlockLoop>(
......
...@@ -50,7 +50,7 @@ __global__ void ...@@ -50,7 +50,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
......
...@@ -54,7 +54,7 @@ __global__ void ...@@ -54,7 +54,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
......
...@@ -344,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -344,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
{ {
#if defined(__gfx90a__) || defined(__gfx940__) #if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else #else
......
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
#include <vector> #include <vector>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -29,20 +28,34 @@ template <typename InputType, ...@@ -29,20 +28,34 @@ template <typename InputType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename OutputType> typename OutputType>
auto get_device_normalize_from_mean_meansquare_instances() struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<InputType, MeanType, MeanSquareType, GammaDataType, BetaDataType>,
ck::Tuple<OutputType>,
Normalize,
2>>
{ {
std::vector<DeviceNormalizeFromMeanMeanSquarePtr> op_ptrs; using DeviceOp = DeviceElementwise<
ck::Tuple<InputType, MeanType, MeanSquareType, GammaDataType, BetaDataType>,
ck::Tuple<OutputType>,
Normalize,
2>;
if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value && static auto GetInstances()
is_same<MeanSquareType, float>::value && is_same<GammaDataType, half_t>::value &&
is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
{ {
ck::tensor_operation::device::instance:: std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
} if constexpr(is_same<InputType, half_t>::value && is_same<MeanType, float>::value &&
is_same<MeanSquareType, float>::value &&
return op_ptrs; is_same<GammaDataType, half_t>::value &&
} is_same<BetaDataType, half_t>::value && is_same<OutputType, half_t>::value)
{
ck::tensor_operation::device::instance::
add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs);
}
return op_ptrs;
};
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -245,11 +245,11 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( ...@@ -245,11 +245,11 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
BF16, BF16,
...@@ -260,10 +260,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( ...@@ -260,10 +260,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
F16, F16,
...@@ -274,10 +274,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( ...@@ -274,10 +274,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
F32, F32,
...@@ -288,10 +288,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( ...@@ -288,10 +288,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
int8_t, int8_t,
...@@ -433,28 +433,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -433,28 +433,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, KZYXGC> && is_same_v<OutLayout, NDHWGK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
} }
} }
......
...@@ -9,34 +9,33 @@ ...@@ -9,34 +9,33 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp" #include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_instances( template <typename InDataType,
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&); typename AccDataType,
void add_device_softmax_f16_f16_rank4_instances( typename OutDataType,
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&); index_t Rank,
index_t NumReduceDim>
void add_device_softmax_f32_f32_rank3_instances( struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftmax<InDataType,
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&); AccDataType,
void add_device_softmax_f32_f32_rank4_instances( OutDataType,
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&); PassThrough,
PassThrough,
void add_device_softmax_i8_i8_rank3_instances( Rank,
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3>>&); NumReduceDim>>
void add_device_softmax_i8_i8_rank4_instances(
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4>>&);
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>>
{ {
using DeviceOp = using DeviceOp = DeviceSoftmax<InDataType,
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>; AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory< ...@@ -46,25 +45,49 @@ struct DeviceOperationInstanceFactory<
std::is_same_v<OutDataType, F16>) std::is_same_v<OutDataType, F16>)
{ {
if constexpr(Rank == 3) if constexpr(Rank == 3)
add_device_softmax_f16_f16_rank3_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs);
}
else if constexpr(Rank == 4) else if constexpr(Rank == 4)
add_device_softmax_f16_f16_rank4_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs);
}
} }
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> && else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
std::is_same_v<OutDataType, F32>) std::is_same_v<OutDataType, F32>)
{ {
if constexpr(Rank == 3) if constexpr(Rank == 3)
add_device_softmax_f32_f32_rank3_instances(op_ptrs); {
else if constexpr(Rank == 4) if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank4_instances(op_ptrs); add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs);
} else if constexpr(NumReduceDim == 2)
else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> && add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs);
std::is_same_v<OutDataType, I8>) else if constexpr(NumReduceDim == 3)
{ add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs);
if constexpr(Rank == 3) }
add_device_softmax_i8_i8_rank3_instances(op_ptrs);
else if constexpr(Rank == 4) else if constexpr(Rank == 4)
add_device_softmax_i8_i8_rank4_instances(op_ptrs); {
if constexpr(NumReduceDim == 1)
add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs);
else if constexpr(NumReduceDim == 2)
add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs);
else if constexpr(NumReduceDim == 3)
add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs);
else if constexpr(NumReduceDim == 4)
add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs);
}
} }
return op_ptrs; return op_ptrs;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances);
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce1_instances( void add_device_softmax_f16_f16_rank3_reduce1_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce2_instances( void add_device_softmax_f16_f16_rank3_reduce2_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
namespace instance { namespace instance {
void add_device_softmax_f16_f16_rank3_reduce3_instances( void add_device_softmax_f16_f16_rank3_reduce3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>& instances); std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment