Unverified Commit e7adffb9 authored by Zaida Zhou, committed by GitHub

[Fix] Skip filtered_lrelu ut when cuda is less than 10.2 (#2677)

parent d31b2212
 # Copyright (c) OpenMMLab. All rights reserved.
 import pytest
 import torch
+from mmengine.utils import digit_version
 from mmcv.ops import filtered_lrelu
@@ -113,7 +114,10 @@ class TestFilteredLrelu:
             self.input_tensor, bias=self.bias, flip_filter=True)
         assert out.shape == (1, 3, 16, 16)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
+    @pytest.mark.skipif(
+        not torch.cuda.is_available()
+        or digit_version(torch.version.cuda) < digit_version('10.2'),
+        reason='requires cuda>=10.2')
     def test_filtered_lrelu_cuda(self):
         out = filtered_lrelu(self.input_tensor.cuda(), bias=self.bias.cuda())
         assert out.shape == (1, 3, 16, 16)
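Note on why the new guard is safe on CPU-only builds: torch.version.cuda is None when PyTorch ships without CUDA, but Python's or short-circuits once not torch.cuda.is_available() evaluates to True, so digit_version never receives None. Below is a minimal, self-contained sketch of the same version-gated skip pattern; cuda_at_least and test_my_cuda_op are hypothetical names for illustration, not part of this commit:

import pytest
import torch
from mmengine.utils import digit_version


def cuda_at_least(version: str) -> bool:
    """Return True only if CUDA is available and at least `version`.

    The `and` short-circuits, so `torch.version.cuda` (None on
    CPU-only builds of PyTorch) is never passed to `digit_version`.
    """
    return (torch.cuda.is_available()
            and digit_version(torch.version.cuda) >= digit_version(version))


@pytest.mark.skipif(
    not cuda_at_least('10.2'), reason='requires cuda>=10.2')
def test_my_cuda_op():
    # Stand-in assertion; a real test would call the CUDA op here.
    x = torch.randn(1, 3, 16, 16).cuda()
    assert x.shape == (1, 3, 16, 16)

The decorator in the commit inlines the availability and version checks directly; the short-circuiting or there plays the same role as the and in this helper.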