"test/vscode:/vscode.git/clone" did not exist on "baaad9ec8e01954eaa1741381def3601fb2de35f"
Commit 5f4a0f73 authored by ltqin
Browse files

Change arch position: move the Arch and GemmFeatureEnum template parameters ahead of DeviceOp

parent 508b643f
...@@ -55,10 +55,10 @@ struct ArchitectureEnumSequence ...@@ -55,10 +55,10 @@ struct ArchitectureEnumSequence
return mData[I]; return mData[I];
} }
}; };
template <typename DeviceOp, GemmFeatureEnum Feature> template <GemmFeatureEnum Feature, typename DeviceOp>
struct DeviceOperationInstances; struct DeviceOperationInstances;
template <typename DeviceOp, typename Arch> template <typename Arch, typename DeviceOp>
struct DeviceOperationInstanceCreator; struct DeviceOperationInstanceCreator;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
...@@ -18,7 +18,8 @@ namespace tensor_operation { ...@@ -18,7 +18,8 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
template <index_t NumDimG, template <typename Arch,
index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
...@@ -34,9 +35,9 @@ template <index_t NumDimG, ...@@ -34,9 +35,9 @@ template <index_t NumDimG,
typename C0DEElementwiseOperation, typename C0DEElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec>
typename Arch> struct DeviceOperationInstanceCreator<Arch,
struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
...@@ -52,8 +53,7 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim ...@@ -52,8 +53,7 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>, MaskingSpec>>
Arch>
{ {
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
...@@ -74,11 +74,11 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim ...@@ -74,11 +74,11 @@ struct DeviceOperationInstanceCreator<DeviceBatchedGemmSoftmaxGemmPermute<NumDim
MaskingSpec>; MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances) static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{ {
if constexpr(DeviceOperationInstances<DeviceOp, if constexpr(DeviceOperationInstances<GemmFeatureEnum::Xdl,
GemmFeatureEnum::Xdl>::template is_surport<Arch>()) DeviceOp>::template is_surport<Arch>())
add_device_operation_instances( add_device_operation_instances(
instances, instances,
DeviceOperationInstances<DeviceOp, GemmFeatureEnum::Xdl>::get_device_instances()); DeviceOperationInstances<GemmFeatureEnum::Xdl, DeviceOp>::get_device_instances());
} }
}; };
......
...@@ -27,7 +27,8 @@ template <index_t NumDimG, ...@@ -27,7 +27,8 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
...@@ -43,8 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -43,8 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>, MaskingSpec>>
GemmFeatureEnum::Xdl>
{ {
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -27,7 +27,8 @@ template <index_t NumDimG, ...@@ -27,7 +27,8 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
NumDimK, NumDimK,
...@@ -43,8 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, ...@@ -43,8 +44,7 @@ struct DeviceOperationInstances<DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
C0DEElementwiseOperation, C0DEElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
C1DEElementwiseOperation, C1DEElementwiseOperation,
MaskingSpec>, MaskingSpec>>
GemmFeatureEnum::Xdl>
{ {
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
......
...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -66,8 +66,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -108,8 +108,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -107,8 +107,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -65,8 +65,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>; MaskingSpecialization::MaskOutUpperTriangle>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances( void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
std::vector< std::vector<
...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i ...@@ -106,8 +106,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
PassThrough, PassThrough,
PassThrough, PassThrough,
MaskingSpecialization::MaskDisabled>; MaskingSpecialization::MaskDisabled>;
DeviceOperationInstanceCreator<DeviceOp, ArchitectureEnumSequence<ArchitectureEnum::Gfx908>>:: DeviceOperationInstanceCreator<ArchitectureEnumSequence<ArchitectureEnum::Gfx908>,
add_device_instances(instances); DeviceOp>::add_device_instances(instances);
} }
} // namespace instance } // namespace instance
......
...@@ -136,9 +136,9 @@ int main() ...@@ -136,9 +136,9 @@ int main()
// get device op instances // get device op instances
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
ck::tensor_operation::device::instance::DeviceOperationInstanceCreator< ck::tensor_operation::device::instance::DeviceOperationInstanceCreator<
DeviceOp, ck::tensor_operation::device::instance::ArchitectureEnumSequence<
ck::tensor_operation::device::instance::GemmFeatureEnum::Xdl>:: ck::tensor_operation::device::instance::ArchitectureEnum::All>,
add_device_instances(op_ptrs); DeviceOp>::add_device_instances(op_ptrs);
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment