Commit 46233520 authored by ltqin
Browse files

add device creator

parent ad45a6cd
......@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
});
}
// Primary template of the per-operation instance creator. It is declared but
// not defined here, so only (partial) specializations are usable. The
// specializations visible in this change are keyed on
// DeviceBatchedGemmSoftmaxGemmPermute<...> and expose a static
// create_device_instances() whose result is passed to
// add_device_operation_instances().
//
// DeviceOp - the device operation type a specialization is selected on.
// Tag      - defaults to void; no use of it is visible in this change
//            (presumably an extra disambiguation hook - confirm before relying on it).
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceCreator;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -10,7 +10,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp"
......@@ -55,24 +54,25 @@ void add_device_instances(
C1DEElementwiseOperation,
MaskingSpec>>>& instances)
{
add_device_operation_instances(instances,
create_device_instances<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>());
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp>::create_device_instances());
}
} // namespace instance
} // namespace device
......
......@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device {
namespace instance {
namespace bf16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
>;
} // namespace bf16_data
template <
index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value, bool>::type = false>
auto create_device_instances()
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{
return bf16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
F32,
Acc0BiasDataType,
C0DEElementwiseOperation,
MaskingSpec>{};
}
// Returns a value-initialized object of the bf16 instance-list alias declared
// in namespace bf16_data. Only a subset of the enclosing specialization's
// template parameters is forwarded: the five dimension counts,
// Acc0BiasDataType, C0DEElementwiseOperation and MaskingSpec. The data type is
// fixed to ck::bhalf_t; Acc1BiasDataType and the remaining element-wise
// operations are not passed on.
static auto create_device_instances()
{
return bf16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::bhalf_t,
// NOTE(review): presumably the accumulation data type - confirm against the alias's parameter list.
float,
Acc0BiasDataType,
C0DEElementwiseOperation,
MaskingSpec>{};
}
};
} // namespace instance
} // namespace device
......
......@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace device {
namespace instance {
namespace fp16_data {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
>;
} // namespace fp16_data
template <
index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value, bool>::type = false>
auto create_device_instances()
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>
{
return fp16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
F32,
Acc0BiasDataType,
C0DEElementwiseOperation,
MaskingSpec>{};
}
// Returns a value-initialized object of the fp16 instance-list alias declared
// in namespace fp16_data. Only a subset of the enclosing specialization's
// template parameters is forwarded: the five dimension counts,
// Acc0BiasDataType, C0DEElementwiseOperation and MaskingSpec. The data type is
// fixed to ck::half_t; Acc1BiasDataType and the remaining element-wise
// operations are not passed on.
static auto create_device_instances()
{
return fp16_data::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ck::half_t,
// NOTE(review): presumably the accumulation data type - confirm against the alias's parameter list.
float,
Acc0BiasDataType,
C0DEElementwiseOperation,
MaskingSpec>{};
}
};
} // namespace instance
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment