Commit 46233520 authored by ltqin's avatar ltqin
Browse files

add device creator

parent ad45a6cd
...@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins ...@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
}); });
} }
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp"
...@@ -55,8 +54,7 @@ void add_device_instances( ...@@ -55,8 +54,7 @@ void add_device_instances(
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>>>& instances) MaskingSpec>>>& instances)
{ {
add_device_operation_instances(instances, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
create_device_instances<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
...@@ -72,7 +70,9 @@ void add_device_instances( ...@@ -72,7 +70,9 @@ void add_device_instances(
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>()); MaskingSpec>;
add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp>::create_device_instances());
} }
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,8 +14,6 @@ namespace tensor_operation { ...@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
namespace bf16_data { namespace bf16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -65,16 +63,11 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo ...@@ -65,16 +63,11 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on // clang-format on
>; >;
} // namespace bf16_data } // namespace bf16_data
template < template <index_t NumDimG,
index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -82,10 +75,27 @@ template < ...@@ -82,10 +75,27 @@ template <
typename C0DEElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec>
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value, bool>::type = false> struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
auto create_device_instances() NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{ {
static auto create_device_instances()
{
return bf16_data:: return bf16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances< device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG, NumDimG,
...@@ -93,12 +103,13 @@ auto create_device_instances() ...@@ -93,12 +103,13 @@ auto create_device_instances()
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
ADataType, ck::bhalf_t,
F32, float,
Acc0BiasDataType, Acc0BiasDataType,
C0DEElementwiseOperation, C0DEElementwiseOperation,
MaskingSpec>{}; MaskingSpec>{};
} }
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,8 +14,6 @@ namespace tensor_operation { ...@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
namespace fp16_data { namespace fp16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -65,16 +63,11 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo ...@@ -65,16 +63,11 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on // clang-format on
>; >;
} // namespace fp16_data } // namespace fp16_data
template < template <index_t NumDimG,
index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -82,10 +75,27 @@ template < ...@@ -82,10 +75,27 @@ template <
typename C0DEElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec>
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value, bool>::type = false> struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
auto create_device_instances() NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{ {
static auto create_device_instances()
{
return fp16_data:: return fp16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances< device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG, NumDimG,
...@@ -93,12 +103,13 @@ auto create_device_instances() ...@@ -93,12 +103,13 @@ auto create_device_instances()
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
ADataType, ck::half_t,
F32, float,
Acc0BiasDataType, Acc0BiasDataType,
C0DEElementwiseOperation, C0DEElementwiseOperation,
MaskingSpec>{}; MaskingSpec>{};
} }
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment