Commit f084651e authored by ltqin's avatar ltqin
Browse files

fix enum

parent 185af92b
......@@ -34,14 +34,16 @@ enum struct ArchitectureEnum
None,
Gfx908,
Gfx90a,
Gfx940,
Gfx1030,
All
};
enum struct GemmFeatureEnum
enum struct ArchFeatureEnum
{
None,
Xdl,
Dl,
Wmd
Wmma
};
template <ArchitectureEnum... Is>
struct ArchitectureEnumSequence
......@@ -55,7 +57,7 @@ struct ArchitectureEnumSequence
return mData[I];
}
};
template <GemmFeatureEnum Feature, typename DeviceOp>
template <ArchFeatureEnum Feature, typename DeviceOp>
struct DeviceOperationInstances;
template <typename Arch, typename DeviceOp>
......
......@@ -74,11 +74,11 @@ struct DeviceOperationInstanceCreator<Arch,
MaskingSpec>;
static void add_device_instances(std::vector<std::unique_ptr<DeviceOp>>& instances)
{
if constexpr(DeviceOperationInstances<GemmFeatureEnum::Xdl,
if constexpr(DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceOp>::template is_surport<Arch>())
add_device_operation_instances(
instances,
DeviceOperationInstances<GemmFeatureEnum::Xdl, DeviceOp>::get_device_instances());
DeviceOperationInstances<ArchFeatureEnum::Xdl, DeviceOp>::get_device_instances());
}
};
......
......@@ -27,7 +27,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
struct DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
......@@ -95,7 +95,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a)
Archs::At(I) == ArchitectureEnum::Gfx90a ||
Archs::At(I) == ArchitectureEnum::Gfx940)
is_surport = true;
});
return is_surport;
......
......@@ -27,7 +27,7 @@ template <index_t NumDimG,
typename B1ElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
struct DeviceOperationInstances<ArchFeatureEnum::Xdl,
DeviceBatchedGemmSoftmaxGemmPermute<NumDimG,
NumDimM,
NumDimN,
......@@ -96,7 +96,8 @@ struct DeviceOperationInstances<GemmFeatureEnum::Xdl,
ck::static_for<0, Archs::mSize, 1>{}([&](auto I) {
if constexpr(Archs::At(I) == ArchitectureEnum::All ||
Archs::At(I) == ArchitectureEnum::Gfx908 ||
Archs::At(I) == ArchitectureEnum::Gfx90a)
Archs::At(I) == ArchitectureEnum::Gfx90a ||
Archs::At(I) == ArchitectureEnum::Gfx940)
is_surport = true;
});
return is_surport;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment