"...composable_kernel.git" did not exist on "0a08477b66474ade84147807fc064b923490ca6b"
Unverified commit ac76519a authored by Adam Osewski, committed by GitHub

Merge branch 'develop' into aosewski/gemm_tile_loop

parents a70c6283 578142db
@@ -10,7 +10,7 @@
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef __fp16__
void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmStreamK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
@@ -63,6 +63,7 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// clang-format on
>;
@@ -97,6 +98,7 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// clang-format on
>;
@@ -131,6 +133,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// TODO: After enable, add instance for small conv.K and conv.C
#endif
// clang-format on
>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "device_grouped_conv2d_fwd_common.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename InLayout,
typename WeiLayout,
typename DsLayout,
......
@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace ck {
namespace tensor_operation {
namespace device {
@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef __fp16__
// FP16
void add_device_normalization_rank_2_1_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 2, 1>>>&);
@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
void add_device_normalization_rank_5_3_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, PassThrough, 5, 3>>>&);
#endif
#ifdef __fp32__
// FP32
void add_device_normalization_rank_2_1_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 2, 1>>>&);
@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
void add_device_normalization_rank_5_3_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, PassThrough, 5, 3>>>&);
#endif
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<XDataType, F16> && is_same_v<GammaDataType, F16> &&
is_same_v<BetaDataType, F16> && is_same_v<YDataType, F16>)
{
@@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
#endif
#ifdef __fp32__
if constexpr(is_same_v<XDataType, F32> && is_same_v<GammaDataType, F32> &&
is_same_v<BetaDataType, F32> && is_same_v<YDataType, F32>)
{
if constexpr(Rank == 2 && NumReduceDim == 1)
{
@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
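The hunks above all apply one pattern: each data type's instance declarations and the matching branch inside GetInstances are wrapped in a preprocessor guard (__fp16__, __fp32__), and the former else-if chain is split into independent guarded if blocks so a build without a given data type never references its instances. Below is a minimal, self-contained sketch of that pattern, not the Composable Kernel API: FactoryOpInstance and the add_*_instances helpers are hypothetical stand-ins, and the sketch assumes __fp16__/__fp32__ are defined by the build only when the corresponding type is enabled.

// Sketch of the guarded-factory pattern; names here are illustrative only.
#include <memory>
#include <vector>

struct FactoryOpInstance
{
    virtual ~FactoryOpInstance() = default;
};

#ifdef __fp16__
// Compiled only in fp16-enabled builds.
void add_f16_instances(std::vector<std::unique_ptr<FactoryOpInstance>>& op_ptrs)
{
    op_ptrs.push_back(std::make_unique<FactoryOpInstance>());
}
#endif

#ifdef __fp32__
// Compiled only in fp32-enabled builds.
void add_f32_instances(std::vector<std::unique_ptr<FactoryOpInstance>>& op_ptrs)
{
    op_ptrs.push_back(std::make_unique<FactoryOpInstance>());
}
#endif

std::vector<std::unique_ptr<FactoryOpInstance>> get_instances()
{
    std::vector<std::unique_ptr<FactoryOpInstance>> op_ptrs;
#ifdef __fp16__
    add_f16_instances(op_ptrs); // each data type registered behind its own guard
#endif
#ifdef __fp32__
    add_f32_instances(op_ptrs); // disabling a dtype drops its instances from the build
#endif
    return op_ptrs;
}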
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef __fp16__
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<
@@ -36,7 +36,8 @@ void add_device_pool2d_fwd_nhwc_f16_instances(
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
#endif
#ifdef __fp32__
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<
@@ -50,7 +51,7 @@ void add_device_pool2d_fwd_nhwc_f32_instances(
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
@@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
#endif
#ifdef __fp32__
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef __fp16__
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<
@@ -36,7 +36,8 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, MaxOp, true>>>&);
#endif
#ifdef __fp32__
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<
@@ -50,7 +51,7 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<
std::unique_ptr<DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef __fp16__
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
@@ -88,8 +89,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
#endif
#ifdef __fp32__
if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef DL_KERNELS
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
#endif
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
}
}
@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
}
}
@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
}
}
@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
if constexpr(is_same_v<Activation, PassThrough>)
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
}
}
@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
\ No newline at end of file
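The quantization factory above uses two levels of guards: the whole namespace block sits behind #ifdef __int8__, while the DL-specific declarations and registration calls sit behind an additional #ifdef DL_KERNELS, so the XDL instances remain available whenever int8 support is built. A minimal sketch under those assumptions follows; QuantOpInstance and the add_* helpers are hypothetical stand-ins, not Composable Kernel APIs, and both macros are assumed to come from the build configuration.

// Sketch of nested data-type / kernel-family guards; names are illustrative only.
#include <memory>
#include <vector>

struct QuantOpInstance
{
    virtual ~QuantOpInstance() = default;
};

#ifdef __int8__ // outer guard: the factory exists only in int8-enabled builds

#ifdef DL_KERNELS // inner guard: DL instances are optional even when int8 is on
void add_dl_i8_instances(std::vector<std::unique_ptr<QuantOpInstance>>& op_ptrs)
{
    op_ptrs.push_back(std::make_unique<QuantOpInstance>());
}
#endif

void add_xdl_i8_instances(std::vector<std::unique_ptr<QuantOpInstance>>& op_ptrs)
{
    op_ptrs.push_back(std::make_unique<QuantOpInstance>());
}

std::vector<std::unique_ptr<QuantOpInstance>> get_int8_instances()
{
    std::vector<std::unique_ptr<QuantOpInstance>> op_ptrs;
#ifdef DL_KERNELS
    add_dl_i8_instances(op_ptrs); // DL kernels only when explicitly enabled
#endif
    add_xdl_i8_instances(op_ptrs); // XDL kernels whenever int8 is enabled
    return op_ptrs;
}

#endif // __int8__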
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()
\ No newline at end of file
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()