Commit 1cda3b80 authored by ltqin's avatar ltqin
Browse files

add builder

parent 7b73260c
...@@ -36,6 +36,9 @@ enum struct ArchitectureEnum ...@@ -36,6 +36,9 @@ enum struct ArchitectureEnum
}; };
template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl> template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl>
struct DeviceOperationInstanceCreator; struct DeviceOperationInstanceCreator;
template <typename DeviceOp, ArchitectureEnum Arch>
struct DeviceOperationInstanceBuilder;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -35,25 +35,25 @@ template <index_t NumDimG, ...@@ -35,25 +35,25 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch = ArchitectureEnum::Xdl> ArchitectureEnum Arch>
void add_device_instances( struct DeviceOperationInstanceBuilder<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, NumDimM,
NumDimM, NumDimN,
NumDimN, NumDimK,
NumDimK, NumDimO,
NumDimO, ADataType,
ADataType, B0DataType,
B0DataType, B1DataType,
B1DataType, CDataType,
CDataType, Acc0BiasDataType,
Acc0BiasDataType, Acc1BiasDataType,
Acc1BiasDataType, AElementwiseOperation,
AElementwiseOperation, B0ElementwiseOperation,
B0ElementwiseOperation, C0DEElementwiseOperation,
C0DEElementwiseOperation, B1ElementwiseOperation,
B1ElementwiseOperation, C1DEElementwiseOperation,
C1DEElementwiseOperation, MaskingSpec>,
MaskingSpec>>>& instances) Arch>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
...@@ -72,9 +72,13 @@ void add_device_instances( ...@@ -72,9 +72,13 @@ void add_device_instances(
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>; MaskingSpec>;
add_device_operation_instances( static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances()); {
} add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances());
}
};
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskOutUpperTriangle>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
} // namespace instance } // namespace instance
......
...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskOutUpperTriangle>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
} // namespace instance } // namespace instance
......
...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskOutUpperTriangle>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
} // namespace instance } // namespace instance
......
...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>& MaskingSpecialization::MaskOutUpperTriangle>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
...@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>& MaskingSpecialization::MaskDisabled>>>&
instances) instances)
{ {
add_device_instances(instances); using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
} }
} // namespace instance } // namespace instance
......
...@@ -135,25 +135,10 @@ int main() ...@@ -135,25 +135,10 @@ int main()
// get device op instances // get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::add_device_instances< ck::tensor_operation::device::instance::DeviceOperationInstanceBuilder<
2, DeviceOp,
1, ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>::
1, add_device_instances(op_ptrs);
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D00DataType, D01DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment