Commit 508b643f authored by ltqin's avatar ltqin
Browse files

add ArchitectureEnumSequence

parent 8dd5c549
...@@ -30,14 +30,35 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins ...@@ -30,14 +30,35 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
} }
enum struct ArchitectureEnum enum struct ArchitectureEnum
{
None,
Gfx908,
Gfx90a,
Gfx1030,
All
};
enum struct GemmFeatureEnum
{ {
Xdl, Xdl,
Dl Dl,
Wmd
};
template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence
{
static constexpr int mSize = sizeof...(Is);
/// Returns the architecture stored at position `I` of the parameter pack `Is...`.
///
/// Valid indices are `0 .. mSize - 1`. Index `mSize` addresses the trailing
/// sentinel and yields `ArchitectureEnum::None`. No range check is performed.
__host__ __device__ static constexpr ArchitectureEnum At(int I)
{
    // A trailing ArchitectureEnum::None is appended so the array is never
    // zero-length, which would be ill-formed when the pack is empty (mSize == 0).
    constexpr ArchitectureEnum kArchs[mSize + 1] = {Is..., ArchitectureEnum::None};
    return kArchs[I];
}
}; };
template <typename DeviceOp, ArchitectureEnum Arch = ArchitectureEnum::Xdl> template <typename DeviceOp, GemmFeatureEnum Feature>
struct DeviceOperationInstances; struct DeviceOperationInstances;
template <typename DeviceOp, ArchitectureEnum Arch> template <typename DeviceOp, typename Arch>
struct DeviceOperationInstanceCreator; struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -35,7 +35,7 @@ template <index_t NumDimG, ...@@ -35,7 +35,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
ArchitectureEnum Arch> typename Arch>
struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -74,8 +74,11 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim ...@@ -74,8 +74,11 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
MaskingSpec>; MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances) static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{ {
add_device_operation_instances( if constexpr(DeviceOperationInstances<DeviceOp,
instances, DeviceOperationInstances<DeviceOp, Arch>::get_device_instances()); GemmFeatureEnum::Xdl>::template is_surport<Arch>())
add_device_operation_instances(
instances,
DeviceOperationInstances<DeviceOp, GemmFeatureEnum::Xdl>::get_device_instances());
} }
}; };
......
...@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>, MaskingSpec>,
ArchitectureEnum::Xdl> GemmFeatureEnum::Xdl>
{ {
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -88,6 +88,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -88,6 +88,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
// clang-format on // clang-format on
>; >;
template <typename Archs>
static constexpr bool is_surport()
{
bool is_surport = false;
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a)
is_surport = true;
});
return is_surport;
}
static auto get_device_instances() static auto get_device_instances()
{ {
return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{}; return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{};
......
...@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -44,7 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>, MaskingSpec>,
ArchitectureEnum::Xdl> GemmFeatureEnum::Xdl>
{ {
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -89,6 +89,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -89,6 +89,18 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
// clang-format on // clang-format on
>; >;
template <typename Archs>
static constexpr bool is_surport()
{
bool is_surport = false;
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a)
is_surport = true;
});
return is_surport;
}
static auto get_device_instances() static auto get_device_instances()
{ {
return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{}; return device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances{};
......
...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnum::Xdl>::add_device_instances( DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>::
instances); add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -137,7 +137,7 @@ int main() ...@@ -137,7 +137,7 @@ int main()
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::DeviceOperationInstanceCreator< ck::tensor_operation::device::instance::DeviceOperationInstanceCreator<
DeviceOp, DeviceOp,
ck::tensor_operation::device::instance::ArchitectureEnum::Xdl>:: ck::tensor_operation::device::instance::GemmFeatureEnum::Xdl>::
add_device_instances(op_ptrs); add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment