Unverified Commit f4ab3e7e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[FBcode->GH] Better logic for ignoring CPU tests on GPU CI machines (#4025) (#4062)

parent 686ff598
...@@ -8,6 +8,9 @@ def pytest_configure(config): ...@@ -8,6 +8,9 @@ def pytest_configure(config):
config.addinivalue_line( config.addinivalue_line(
"markers", "needs_cuda: mark for tests that rely on a CUDA device" "markers", "needs_cuda: mark for tests that rely on a CUDA device"
) )
config.addinivalue_line(
"markers", "dont_collect: mark for tests that should not be collected"
)
def pytest_collection_modifyitems(items): def pytest_collection_modifyitems(items):
...@@ -47,6 +50,10 @@ def pytest_collection_modifyitems(items): ...@@ -47,6 +50,10 @@ def pytest_collection_modifyitems(items):
# to run the CPU-only tests. # to run the CPU-only tests.
item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG)) item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG))
if item.get_closest_marker('dont_collect') is not None:
# currently, this is only used for some tests we're sure we dont want to run on fbcode
continue
out_items.append(item) out_items.append(item)
items[:] = out_items items[:] = out_items
...@@ -358,6 +358,18 @@ def test_encode_jpeg_errors(): ...@@ -358,6 +358,18 @@ def test_encode_jpeg_errors():
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8)) encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
def _collect_if(cond):
# TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
# are removed
def _inner(test_func):
if cond:
return test_func
else:
return pytest.mark.dont_collect(test_func)
return _inner
@_collect_if(cond=IS_WINDOWS)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg") for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
...@@ -389,6 +401,7 @@ def test_encode_jpeg_reference(img_path): ...@@ -389,6 +401,7 @@ def test_encode_jpeg_reference(img_path):
assert_equal(jpeg_bytes, pil_bytes) assert_equal(jpeg_bytes, pil_bytes)
@_collect_if(cond=IS_WINDOWS)
@pytest.mark.parametrize('img_path', [ @pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
for jpeg_path in get_images(ENCODE_JPEG, ".jpg") for jpeg_path in get_images(ENCODE_JPEG, ".jpg")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment