Commit 508b643f authored by ltqin
Browse files

add ArchitectureEnumSequence

parent 8dd5c549
......@@ -30,14 +30,35 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
}
// Tags for the GPU architectures an instance set can be requested for.
// Enumerators rely on implicit numbering, starting at 0 with None.
// NOTE(review): the Gfx* names presumably correspond to the AMD targets of
// the same name -- confirm against the build configuration.
enum class ArchitectureEnum
{
    None,    // placeholder value; no architecture
    Gfx908,
    Gfx90a,
    Gfx1030,
    All      // wildcard entry
};
// Selects which instance set a DeviceOperationInstances specialization
// provides; it is the second template argument of that template.
// Enumerators rely on implicit numbering, starting at 0 with Xdl.
// NOTE(review): `Wmd` may have been intended as `Wmma` -- confirm before
// adding specializations that use it.
enum struct GemmFeatureEnum
{
    Xdl,
    Dl,
    Wmd
};
// Compile-time list of ArchitectureEnum values carried as a type.
//
//   mSize  - number of architectures in the list (may be 0).
//   At(I)  - the I-th architecture of the list.
template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence
{
    static constexpr int mSize = static_cast<int>(sizeof...(Is));

    // Returns the entry at position I. The index is not range-checked:
    // valid positions are [0, mSize), and position mSize yields the
    // trailing ArchitectureEnum::None placeholder.
    __host__ __device__ static constexpr ArchitectureEnum At(int I)
    {
        // One extra trailing element keeps the array non-empty when the
        // pack is empty, which the compiler would otherwise reject.
        const ArchitectureEnum entries[mSize + 1] = {Is..., ArchitectureEnum::None};

        return entries[I];
    }
};
// Primary template, declared only. Specializations are keyed on the device
// operation type and on the GemmFeatureEnum value naming the instance set.
template <typename DeviceOp, GemmFeatureEnum Feature>
struct DeviceOperationInstances;
// Primary template, declared only. `Arch` is a type, not an enum value:
// callers pass an ArchitectureEnumSequence<...> listing the requested
// architectures.
template <typename DeviceOp, typename Arch>
struct DeviceOperationInstanceCreator;
} // namespace instance
} // namespace device
......
......@@ -35,7 +35,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch>
typename Arch>
struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
......@@ -74,8 +74,11 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
MaskingSpec>;
// Appends the Xdl instance set for DeviceOp to `instances`, but only when
// that instance set reports the architecture list `Arch` as supported.
// The check is an `if constexpr`: for an unsupported list the call is
// discarded at compile time and `instances` is left untouched.
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
    if constexpr(DeviceOperationInstances<DeviceOp,
                                          GemmFeatureEnum::Xdl>::template is_surport<Arch>())
    {
        add_device_operation_instances(
            instances,
            DeviceOperationInstances<DeviceOp, GemmFeatureEnum::Xdl>::get_device_instances());
    }
}
};
......
......@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>,
ArchitectureEnum::Xdl>
GemmFeatureEnum::Xdl>
{
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -88,6 +88,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
// clang-format on
>;
// Reports whether the architecture list `Archs` contains at least one entry
// accepted by this instance set: All, Gfx908 or Gfx90a.
// `Archs` must provide `mSize` and a constexpr `At(int)`; an empty list
// yields false. The spelling of the name is kept because callers use it.
template <typename Archs>
static constexpr bool is_surport()
{
    for(int idx = 0; idx < Archs::mSize; ++idx)
    {
        const auto arch = Archs::At(idx);

        const bool accepted = arch == ArchitectureEnum::All ||
                              arch == ArchitectureEnum::Gfx908 ||
                              arch == ArchitectureEnum::Gfx90a;
        if(accepted)
            return true;
    }
    return false;
}
static auto get_device_instances()
{
return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{};
......
......@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
B1ElementwiseOperation,
C1DEElementwiseOperation,
MaskingSpec>,
ArchitectureEnum::Xdl>
GemmFeatureEnum::Xdl>
{
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -89,6 +89,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
// clang-format on
>;
// Reports whether the architecture list `Archs` contains at least one entry
// accepted by this instance set: All, Gfx908 or Gfx90a.
// `Archs` must provide `mSize` and a constexpr `At(int)`; an empty list
// yields false. The spelling of the name is kept because callers use it.
template <typename Archs>
static constexpr bool is_surport()
{
    for(int idx = 0; idx < Archs::mSize; ++idx)
    {
        const auto arch = Archs::At(idx);

        const bool accepted = arch == ArchitectureEnum::All ||
                              arch == ArchitectureEnum::Gfx908 ||
                              arch == ArchitectureEnum::Gfx90a;
        if(accepted)
            return true;
    }
    return false;
}
static auto get_device_instances()
{
return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{};
......
......@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
......@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
} // namespace instance
......
......@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector<
......@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances(
instances);
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
add_device_instances(instances);
}
} // namespace instance
......
......@@ -137,7 +137,7 @@ int main()
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::DeviceOperationInstanceCreator<
DeviceOp,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>::
ck::tensor_operation::device::instance::GemmFeatureEnum::Xdl>::
add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment