Commit 46233520 authored by ltqin's avatar ltqin
Browse files

add device creator

parent ad45a6cd
...@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins ...@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
}); });
} }
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp"
...@@ -55,24 +54,25 @@ void add_device_instances( ...@@ -55,24 +54,25 @@ void add_device_instances(
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>>>& instances) MaskingSpec>>>& instances)
{ {
add_device_operation_instances(instances, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
create_device_instances<NumDimG, NumDimM,
NumDimM, NumDimN,
NumDimN, NumDimK,
NumDimK, NumDimO,
NumDimO, ADataType,
ADataType, B0DataType,
B0DataType, B1DataType,
B1DataType, CDataType,
CDataType, Acc0BiasDataType,
Acc0BiasDataType, Acc1BiasDataType,
Acc1BiasDataType, AElementwiseOperation,
AElementwiseOperation, B0ElementwiseOperation,
B0ElementwiseOperation, C0DEElementwiseOperation,
C0DEElementwiseOperation, B1ElementwiseOperation,
B1ElementwiseOperation, C1DEElementwiseOperation,
C1DEElementwiseOperation, MaskingSpec>;
MaskingSpec>()); add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp>::create_device_instances());
} }
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,8 +14,6 @@ namespace tensor_operation { ...@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
namespace bf16_data { namespace bf16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo ...@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on // clang-format on
>; >;
} // namespace bf16_data } // namespace bf16_data
template < template <index_t NumDimG,
index_t NumDimG, index_t NumDimM,
index_t NumDimM, index_t NumDimN,
index_t NumDimN, index_t NumDimK,
index_t NumDimK, index_t NumDimO,
index_t NumDimO, typename Acc0BiasDataType,
typename ADataType, typename Acc1BiasDataType,
typename B0DataType, typename AElementwiseOperation,
typename B1DataType, typename B0ElementwiseOperation,
typename CDataType, typename C0DEElementwiseOperation,
typename Acc0BiasDataType, typename B1ElementwiseOperation,
typename Acc1BiasDataType, typename C1DEElementwiseOperation,
typename AElementwiseOperation, MaskingSpecialization MaskingSpec>
typename B0ElementwiseOperation, struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
typename C0DEElementwiseOperation, NumDimM,
typename B1ElementwiseOperation, NumDimN,
typename C1DEElementwiseOperation, NumDimK,
MaskingSpecialization MaskingSpec, NumDimO,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value, bool>::type = false> ck::bhalf_t,
auto create_device_instances() ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{ {
return bf16_data:: static auto create_device_instances()
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances< {
NumDimG, return bf16_data::
NumDimM, device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimN, NumDimG,
NumDimK, NumDimM,
NumDimO, NumDimN,
ADataType, NumDimK,
F32, NumDimO,
Acc0BiasDataType, ck::bhalf_t,
C0DEElementwiseOperation, float,
MaskingSpec>{}; Acc0BiasDataType,
} C0DEElementwiseOperation,
MaskingSpec>{};
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -14,8 +14,6 @@ namespace tensor_operation { ...@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
namespace fp16_data { namespace fp16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo ...@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on // clang-format on
>; >;
} // namespace fp16_data } // namespace fp16_data
template < template <index_t NumDimG,
index_t NumDimG, index_t NumDimM,
index_t NumDimM, index_t NumDimN,
index_t NumDimN, index_t NumDimK,
index_t NumDimK, index_t NumDimO,
index_t NumDimO, typename Acc0BiasDataType,
typename ADataType, typename Acc1BiasDataType,
typename B0DataType, typename AElementwiseOperation,
typename B1DataType, typename B0ElementwiseOperation,
typename CDataType, typename C0DEElementwiseOperation,
typename Acc0BiasDataType, typename B1ElementwiseOperation,
typename Acc1BiasDataType, typename C1DEElementwiseOperation,
typename AElementwiseOperation, MaskingSpecialization MaskingSpec>
typename B0ElementwiseOperation, struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
typename C0DEElementwiseOperation, NumDimM,
typename B1ElementwiseOperation, NumDimN,
typename C1DEElementwiseOperation, NumDimK,
MaskingSpecialization MaskingSpec, NumDimO,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value, bool>::type = false> ck::half_t,
auto create_device_instances() ck::half_t,
ck::half_t,
ck::half_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{ {
return fp16_data:: static auto create_device_instances()
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances< {
NumDimG, return fp16_data::
NumDimM, device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimN, NumDimG,
NumDimK, NumDimM,
NumDimO, NumDimN,
ADataType, NumDimK,
F32, NumDimO,
Acc0BiasDataType, ck::half_t,
C0DEElementwiseOperation, float,
MaskingSpec>{}; Acc0BiasDataType,
} C0DEElementwiseOperation,
MaskingSpec>{};
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment