Unverified Commit f3d6c6be authored by ptrblck's avatar ptrblck Committed by GitHub
Browse files

disable tf32 in cuDNN for classification models (#7634)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 0ab7d05c
...@@ -25,6 +25,14 @@ ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" ...@@ -25,6 +25,14 @@ ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1" SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
@contextlib.contextmanager
def disable_tf32():
previous = torch.backends.cudnn.allow_tf32
torch.backends.cudnn.allow_tf32 = False
yield
torch.backends.cudnn.allow_tf32 = previous
def list_model_fns(module): def list_model_fns(module):
return [get_model_builder(name) for name in list_models(module)] return [get_model_builder(name) for name in list_models(module)]
...@@ -671,6 +679,7 @@ def test_vitc_models(model_fn, dev): ...@@ -671,6 +679,7 @@ def test_vitc_models(model_fn, dev):
test_classification_model(model_fn, dev) test_classification_model(model_fn, dev)
@disable_tf32() # see: https://github.com/pytorch/vision/issues/7618
@pytest.mark.parametrize("model_fn", list_model_fns(models)) @pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
def test_classification_model(model_fn, dev): def test_classification_model(model_fn, dev):
...@@ -682,11 +691,6 @@ def test_classification_model(model_fn, dev): ...@@ -682,11 +691,6 @@ def test_classification_model(model_fn, dev):
model_name = model_fn.__name__ model_name = model_fn.__name__
if SKIP_BIG_MODEL and is_skippable(model_name, dev): if SKIP_BIG_MODEL and is_skippable(model_name, dev):
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model") pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
if model_name == "resnet101" and dev == "cuda":
# TODO: Investigate the Failure with CUDA 11.8: https://github.com/pytorch/vision/issues/7618
# TODO: Investigate/followup on previous failure: https://github.com/pytorch/vision/issues/7143
# its not happening on CI with CUDA 11.8 anymore. Follow up is needed if its still not resolved.
pytest.xfail("https://github.com/pytorch/vision/issues/7618")
kwargs = {**defaults, **_model_params.get(model_name, {})} kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes") num_classes = kwargs.get("num_classes")
input_shape = kwargs.pop("input_shape") input_shape = kwargs.pop("input_shape")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment