Unverified Commit 81700555, authored by Nicolas Hug, committed by GitHub

Test some flaky detection models on float64 instead of float32 (#7204)

parent d75a5241
@@ -29,7 +29,7 @@ def list_model_fns(module):
     return [get_model_builder(name) for name in list_models(module)]
 
 
-def _get_image(input_shape, real_image, device):
+def _get_image(input_shape, real_image, device, dtype=None):
     """This routine loads a real or random image based on `real_image` argument.
     Currently, the real image is utilized for the following list of models:
         - `retinanet_resnet50_fpn`,
@@ -60,10 +60,10 @@ def _get_image(input_shape, real_image, device):
         convert_tensor = transforms.ToTensor()
         image = convert_tensor(img)
         assert tuple(image.size()) == input_shape
-        return image.to(device=device)
+        return image.to(device=device, dtype=dtype)
 
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
-    return torch.rand(input_shape).to(device=device)
+    return torch.rand(input_shape).to(device=device, dtype=dtype)
 
 
 @pytest.fixture
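The new `dtype` argument defaults to `None`, which leaves the image in the default dtype; passing `torch.float64` casts it after creation, so the CPU RNG draw itself is unchanged across devices. A minimal, self-contained sketch of just the random-image branch (the helper name `_get_image_sketch` is ours, not from the test file):

```python
import torch

def _get_image_sketch(input_shape, device, dtype=None):
    # Simplified stand-in for the random-image branch of _get_image:
    # the RNG draw happens on CPU, and .to(dtype=None) leaves the dtype
    # unchanged, so the old behaviour is preserved when dtype is omitted.
    return torch.rand(input_shape).to(device=device, dtype=dtype)

x64 = _get_image_sketch((3, 300, 300), device="cpu", dtype=torch.float64)
x32 = _get_image_sketch((3, 300, 300), device="cpu")
assert x64.dtype == torch.float64
assert x32.dtype == torch.get_default_dtype()  # float32 unless overridden
```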
@@ -278,6 +278,11 @@ autocast_flaky_numerics = (
 # tests under test_quantized_classification_model will be skipped for the following models.
 quantized_flaky_models = ("inception_v3", "resnet50")
 
+# The tests for the following detection models are flaky.
+# We run those tests on float64 to avoid floating point errors.
+# FIXME: we shouldn't have to do that :'/
+detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")
+
 
 # The following contains configuration parameters for all models which are used by
 # the _test_*_model methods.
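As background for the "floating point errors" mentioned in the new comment, a tiny illustration (not part of the change) of the precision gap between the two dtypes:

```python
import torch

# float32 carries ~7 significant decimal digits, float64 ~16. A contribution
# of 1e-8 is lost entirely when added to 1.0 in float32 but survives in
# float64; differences this small in scores can flip thresholding or NMS
# decisions, which is one way detection tests end up flaky.
for dtype in (torch.float32, torch.float64):
    one = torch.tensor(1.0, dtype=dtype)
    tiny = torch.tensor(1e-8, dtype=dtype)
    print(dtype, ((one + tiny) - one).item())
# torch.float32 -> 0.0
# torch.float64 -> ~1e-8
```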
@@ -777,13 +782,17 @@ def test_detection_model(model_fn, dev):
         "input_shape": (3, 300, 300),
     }
     model_name = model_fn.__name__
+    if model_name in detection_flaky_models:
+        dtype = torch.float64
+    else:
+        dtype = torch.get_default_dtype()
     kwargs = {**defaults, **_model_params.get(model_name, {})}
     input_shape = kwargs.pop("input_shape")
     real_image = kwargs.pop("real_image", False)
 
     model = model_fn(**kwargs)
-    model.eval().to(device=dev)
-    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
+    model.eval().to(device=dev, dtype=dtype)
+    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev, dtype=dtype)
     model_input = [x]
     with torch.no_grad(), freeze_rng_state():
         out = model(model_input)
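Taken together, the change amounts to the following pattern. This standalone sketch uses one of the flaky models with illustrative builder arguments (`weights=None`, `weights_backbone=None`, `num_classes=5`), which are not the test suite's actual `_model_params`:

```python
import torch
import torchvision

detection_flaky_models = ("keypointrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn", "maskrcnn_resnet50_fpn_v2")

model_name = "maskrcnn_resnet50_fpn"
dtype = torch.float64 if model_name in detection_flaky_models else torch.get_default_dtype()

# No weights are downloaded; num_classes is arbitrary for a shape/precision check.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=None, weights_backbone=None, num_classes=5
)
model.eval().to(device="cpu", dtype=dtype)

x = torch.rand(3, 300, 300).to(device="cpu", dtype=dtype)
with torch.no_grad():
    out = model([x])  # list with one dict: "boxes", "labels", "scores", "masks"
print({k: v.dtype for k, v in out[0].items()})
```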