Commit f084651e authored by ltqin's avatar ltqin
Browse files

fix enum

parent 185af92b
...@@ -34,14 +34,16 @@ enum struct ArchitectureEnum ...@@ -34,14 +34,16 @@ enum struct ArchitectureEnum
None, None,
Gfx908, Gfx908,
Gfx90a, Gfx90a,
Gfx940,
Gfx1030, Gfx1030,
All All
}; };
enum struct GemmFeatureEnum enum struct ArchFeatureEnum
{ {
None,
Xdl, Xdl,
Dl, Dl,
Wmd Wmma
}; };
template <ArchitectureEnum... Is> template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence struct ArchitectureEnumSequence
...@@ -55,7 +57,7 @@ struct ArchitectureEnumSequence ...@@ -55,7 +57,7 @@ struct ArchitectureEnumSequence
return mData[I]; return mData[I];
} }
}; };
template <GemmFeatureEnum Feature, typename DeviceOp> template <ArchFeatureEnum Feature, typename DeviceOp>
struct DeviceOperationInstances; struct DeviceOperationInstances;
template <typename Arch, typename DeviceOp> template <typename Arch, typename DeviceOp>
......
...@@ -74,11 +74,11 @@ struct DeviceOperationInstanceCreator<Arch, ...@@ -74,11 +74,11 @@ struct DeviceOperationInstanceCreator<Arch,
MaskingSpec>; MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances) static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{ {
if constexpr(DeviceOperationInstances<GemmFeatureEnum::Xdl, if constexpr(DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceOp>::template is_surport<Arch>()) DeviceOp>::template is_surport<Arch>())
add_device_operation_instances( add_device_operation_instances(
instances, instances,
DeviceOperationInstances<GemmFeatureEnum::Xdl, DeviceOp>::get_device_instances()); DeviceOperationInstances<ArchFeatureEnum::Xdl, DeviceOp>::get_device_instances());
} }
}; };
......
...@@ -27,7 +27,7 @@ template <index_t NumDimG, ...@@ -27,7 +27,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<GemmFeatureEnum::Xdl, struct DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -95,7 +95,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl, ...@@ -95,7 +95,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) { ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All || if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 || Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a) Archs::At(I) == ArchitectureEnum::Gfx90a ||
Archs::At(I) == ArchitectureEnum::Gfx940)
is_surport = true; is_surport = true;
}); });
return is_surport; return is_surport;
......
...@@ -27,7 +27,7 @@ template <index_t NumDimG, ...@@ -27,7 +27,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename C1DEElementwiseOperation, typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec> MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<GemmFeatureEnum::Xdl, struct DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -96,7 +96,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl, ...@@ -96,7 +96,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) { ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All || if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 || Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a) Archs::At(I) == ArchitectureEnum::Gfx90a ||
Archs::At(I) == ArchitectureEnum::Gfx940)
is_surport = true; is_surport = true;
}); });
return is_surport; return is_surport;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment