Commit 1cda3b80 authored by ltqin
Browse files

add builder

parent 7b73260c
......@@ -36,6 +36,9 @@ enum struct ArchitectureEnum
};
// Primary template, declared only (no definition in this hunk).
// DeviceOp is the device operation type; Arch selects the target architecture
// and defaults to ArchitectureEnum::Xdl.
// The builder added in this commit calls
// DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances(),
// so specializations need to provide that static member.
template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl>
struct DeviceOperationInstanceCreator;
// Primary template, declared only (no definition in this hunk).
// Unlike DeviceOperationInstanceCreator, Arch has no default here, so every use
// must name the architecture explicitly.
// The partial specialization for DeviceBatchedGemmSoftmaxGemmPermute in this
// commit exposes a static add_device_instances(instances) member.
template <typename DeviceOp, ArchitectureEnum Arch>
struct DeviceOperationInstanceBuilder;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -35,9 +35,8 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch = ArchitectureEnum::Xdl>
void add_device_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
ArchitectureEnum Arch>
struct DeviceOperationInstanceBuilder<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
......@@ -53,7 +52,8 @@ void add_device_instances(
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>>& instances)
MaskingSpec>,
Arch>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
......@@ -72,9 +72,13 @@ void add_device_instances(
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
// Passes `instances` and the result of
// DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances()
// to add_device_operation_instances.
// instances: caller-supplied vector of std::unique_ptr<DeviceOp>, taken by
//            non-const reference.
// NOTE(review): add_device_operation_instances is defined outside this view;
// presumably it appends the created instances to `instances` - confirm.
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
// The creator is chosen by the same DeviceOp/Arch pair as this builder specialization.
add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances());
}
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
......@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -135,25 +135,10 @@ int main()
// get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::add_device_instances<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D00DataType, D01DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>(op_ptrs);
ck::tensor_operation::device::instance::DeviceOperationInstanceBuilder<
DeviceOp,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>::
add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment