Unverified Commit e7adffb9 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Skip filtered_lrelu ut when cuda is less than 10.2 (#2677)

parent d31b2212
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils import digit_version
from mmcv.ops import filtered_lrelu
......@@ -113,7 +114,10 @@ class TestFilteredLrelu:
self.input_tensor, bias=self.bias, flip_filter=True)
assert out.shape == (1, 3, 16, 16)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not torch.cuda.is_available()
or digit_version(torch.version.cuda) < digit_version('10.2'),
reason='requires cuda>=10.2')
def test_filtered_lrelu_cuda(self):
out = filtered_lrelu(self.input_tensor.cuda(), bias=self.bias.cuda())
assert out.shape == (1, 3, 16, 16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment