Commit 661b0454 authored by ltqin
Browse files

change add device instances function name

parent 5595f635
...@@ -134,8 +134,8 @@ int main() ...@@ -134,8 +134,8 @@ int main()
MaskingSpec>; MaskingSpec>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
DeviceOp>::GetInstances(); ck::tensor_operation::device::instance::add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -19,29 +19,26 @@ namespace device { ...@@ -19,29 +19,26 @@ namespace device {
namespace instance { namespace instance {
extern template void extern template void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
std::vector<std::unique_ptr< 2,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, 1,
1, F16,
F16, F16,
F16, F16,
F16, F16,
F16, ck::Tuple<F16>,
ck::Tuple<F16>, ck::Tuple<>,
ck::Tuple<>, PassThrough,
PassThrough, PassThrough,
PassThrough, ScaleAdd,
ScaleAdd, PassThrough,
PassThrough, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void extern template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -63,29 +60,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_ ...@@ -63,29 +60,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances); instances);
extern template void extern template void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
std::vector<std::unique_ptr< 2,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, 1,
1, BF16,
BF16, BF16,
BF16, BF16,
BF16, BF16,
BF16, ck::Tuple<BF16>,
ck::Tuple<BF16>, ck::Tuple<>,
ck::Tuple<>, PassThrough,
PassThrough, PassThrough,
PassThrough, ScaleAdd,
ScaleAdd, PassThrough,
PassThrough, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void extern template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -107,29 +101,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_ ...@@ -107,29 +101,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances); instances);
extern template void extern template void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
std::vector<std::unique_ptr< 2,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, 1,
1, F16,
F16, F16,
F16, F16,
F16, F16,
F16, ck::Tuple<>,
ck::Tuple<>, ck::Tuple<>,
ck::Tuple<>, PassThrough,
PassThrough, PassThrough,
PassThrough, Scale,
Scale, PassThrough,
PassThrough, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void extern template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -151,29 +142,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_ ...@@ -151,29 +142,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances); instances);
extern template void extern template void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
std::vector<std::unique_ptr< 2,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, 1,
1, BF16,
BF16, BF16,
BF16, BF16,
BF16, BF16,
BF16, ck::Tuple<>,
ck::Tuple<>, ck::Tuple<>,
ck::Tuple<>, PassThrough,
PassThrough, PassThrough,
PassThrough, Scale,
Scale, PassThrough,
PassThrough, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
extern template void extern template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
...@@ -249,8 +237,7 @@ struct DeviceOperationInstanceFactory< ...@@ -249,8 +237,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( add_device_instances(op_ptrs);
op_ptrs);
return op_ptrs; return op_ptrs;
} }
}; };
......
...@@ -35,50 +35,44 @@ template <index_t NumDimG, ...@@ -35,50 +35,44 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory< void add_device_instances(
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
ADataType, ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
AElementwiseOperation, AElementwiseOperation,
B0ElementwiseOperation, B0ElementwiseOperation,
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>> MaskingSpec>>>& instances)
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, add_device_operation_instances(instances,
NumDimM, create_device_instances<NumDimG,
NumDimN, NumDimM,
NumDimK, NumDimN,
NumDimO, NumDimK,
ADataType, NumDimO,
B0DataType, ADataType,
B1DataType, B0DataType,
CDataType, B1DataType,
Acc0BiasDataType, CDataType,
Acc1BiasDataType, Acc0BiasDataType,
AElementwiseOperation, Acc1BiasDataType,
B0ElementwiseOperation, AElementwiseOperation,
C0DEElementwiseOperation, B0ElementwiseOperation,
B1ElementwiseOperation, C0DEElementwiseOperation,
C1DEElementwiseOperation, B1ElementwiseOperation,
MaskingSpec>; C1DEElementwiseOperation,
static auto GetInstances() MaskingSpec>());
{ }
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
op_ptrs);
return op_ptrs;
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -85,38 +85,19 @@ template <index_t NumDimG, ...@@ -85,38 +85,19 @@ template <index_t NumDimG,
typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value || typename enable_if<is_same<remove_cvref_t<ADataType>, ck::half_t>::value ||
is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value, is_same<remove_cvref_t<ADataType>, ck::bhalf_t>::value,
bool>::type = false> bool>::type = false>
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( auto create_device_instances()
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>>& instances)
{ {
add_device_operation_instances( return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
instances, NumDimG,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances< NumDimM,
NumDimG, NumDimN,
NumDimM, NumDimK,
NumDimN, NumDimO,
NumDimK, ADataType,
NumDimO, F32,
ADataType, Acc0BiasDataType,
F32, C0DEElementwiseOperation,
Acc0BiasDataType, MaskingSpec>{};
C0DEElementwiseOperation,
MaskingSpec>{});
} }
} // namespace instance } // namespace instance
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>; ...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template void template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( 2,
std::vector<std::unique_ptr< 1,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, BF16,
1, BF16,
BF16, BF16,
BF16, BF16,
BF16, ck::Tuple<BF16>,
BF16, ck::Tuple<>,
ck::Tuple<BF16>, PassThrough,
ck::Tuple<>, PassThrough,
PassThrough, ScaleAdd,
PassThrough, PassThrough,
ScaleAdd, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
template void template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>; ...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template void template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( 2,
std::vector<std::unique_ptr< 1,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, F16,
1, F16,
F16, F16,
F16, F16,
F16, ck::Tuple<F16>,
F16, ck::Tuple<>,
ck::Tuple<F16>, PassThrough,
ck::Tuple<>, PassThrough,
PassThrough, ScaleAdd,
PassThrough, PassThrough,
ScaleAdd, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
template void template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>; ...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
template void template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( 2,
std::vector<std::unique_ptr< 1,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, BF16,
1, BF16,
BF16, BF16,
BF16, BF16,
BF16, ck::Tuple<>,
BF16, ck::Tuple<>,
ck::Tuple<>, PassThrough,
ck::Tuple<>, PassThrough,
PassThrough, Scale,
PassThrough, PassThrough,
Scale, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
template void template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck { namespace ck {
...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>; ...@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
template void template void add_device_instances(std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( 2,
std::vector<std::unique_ptr< 1,
DeviceBatchedGemmSoftmaxGemmPermute<2, 1,
1, 1,
1, 1,
1, F16,
1, F16,
F16, F16,
F16, F16,
F16, ck::Tuple<>,
F16, ck::Tuple<>,
ck::Tuple<>, PassThrough,
ck::Tuple<>, PassThrough,
PassThrough, Scale,
PassThrough, PassThrough,
Scale, PassThrough,
PassThrough, MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
template void template void add_device_instances(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2, std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1, 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment