"vscode:/vscode.git/clone" did not exist on "364d59d13b64762c3a0e6ce9ebbe4226b8008ed3"
Commit 1cda3b80 authored by ltqin
Browse files

add builder

parent 7b73260c
......@@ -36,6 +36,9 @@ enum struct ArchitectureEnum
};
// Forward declaration of the instance-creator template, parameterised on the
// device operation type and the target architecture. `Arch` defaults to
// ArchitectureEnum::Xdl, so DeviceOperationInstanceCreator<DeviceOp> selects Xdl.
// No primary definition is given here; code in this change calls
// DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances(),
// so each specialization is expected to supply that member
// (NOTE(review): the specializations are not part of this view - confirm).
template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl>
struct DeviceOperationInstanceCreator;
// Forward declaration of the instance-builder template, parameterised on the
// device operation type and the target architecture. `Arch` has no default
// argument here, so every use must name the architecture explicitly
// (e.g. DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>).
// Callers in this change invoke the static member
// add_device_instances(std::vector<std::unique_ptr<DeviceOp>>&) on it, which a
// partial specialization for the concrete device operation has to provide.
template <typename DeviceOp, ArchitectureEnum Arch>
struct DeviceOperationInstanceBuilder;
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -35,25 +35,25 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch = ArchitectureEnum::Xdl>
void add_device_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>>>& instances)
ArchitectureEnum Arch>
struct DeviceOperationInstanceBuilder<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
B0ElementwiseOperation,
C0DEElementwiseOperation,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>,
Arch>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
......@@ -72,9 +72,13 @@ void add_device_instances(
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>;
add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances());
}
// Registers the device-operation instances for this specialization's
// `DeviceOp` / `Arch` pair into the caller-owned `instances` vector.
//
// The instance set is obtained from
// DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances()
// and handed, together with `instances`, to add_device_operation_instances().
// NOTE(review): add_device_operation_instances is defined elsewhere; it
// presumably appends to `instances` rather than replacing its contents - confirm.
//
// @param instances  Vector of owning pointers to `DeviceOp`, passed by
//                   non-const reference so the helper can modify it.
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
add_device_operation_instances(
instances, DeviceOperationInstanceCreator<DeviceOp, Arch>::create_device_instances());
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_instances(instances);
using DeviceOp =
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
......@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_instances(instances);
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceBuilder<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
}
} // namespace instance
......
......@@ -135,25 +135,10 @@ int main()
// get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::add_device_instances<
2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D00DataType, D01DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>(op_ptrs);
ck::tensor_operation::device::instance::DeviceOperationInstanceBuilder<
DeviceOp,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>::
add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment