Unverified Commit ac5e868f authored by vcherepanov-nv's avatar vcherepanov-nv Committed by GitHub
Browse files

Skip fp8 tests on unsupported devices (#2243)


Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
parent 76bced54
......@@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) {
return supported;
}
int GetDeviceComputeCapability(int device_id) {
int major{};
int minor{};
CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id));
CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id));
return major * 10 + minor;
}
template <typename T>
bool IsDTypeSupported(int /* device_id */) {
return true;
}
template <>
bool IsDTypeSupported<test::fp8e5m2>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <>
bool IsDTypeSupported<test::fp8e4m3>(int device_id) {
return GetDeviceComputeCapability(device_id) >= 89;
}
template <typename... Ts>
bool AllDTypesSupported(int device_id) {
return (IsDTypeSupported<Ts>(device_id) && ...);
}
template <typename T>
std::vector<T> CopyMatrix(const std::vector<T>& data, size_t mstart, size_t nstart, size_t msize,
size_t nsize, size_t ld) {
......@@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam<Params> {
template <typename AType, typename BType, typename DType, typename BiasType>
void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) {
if (!AllDTypesSupported<AType, BType, DType, BiasType>(rank_))
GTEST_SKIP() << "FP8 is not supported on device " << rank_;
cudaStream_t stream{};
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment