Unverified commit bfb474b9 authored by Aidyn-A, committed by GitHub

Update test detection model (#6939)

* update tests

* revert replacing expect

* test_detection_model freeze_rng_state

* update tests
parent 4a310f26
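
The commit message above is terse; the substance of the change is that every forward pass whose output feeds a comparison (eager vs. TorchScript, eager vs. FX-traced, eager vs. autocast) is now run under `torch.no_grad()` and `freeze_rng_state()`, and the scripted module is switched to eval mode before it runs. Below is a minimal sketch of that pattern, assuming `freeze_rng_state` is the save/restore-RNG context manager that torchvision's test `common_utils` re-exports from `torch.testing._internal.common_utils`; the helper name `compare_eager_and_scripted` is illustrative and not part of this commit.

import torch
from torch.testing._internal.common_utils import freeze_rng_state  # assumed import path for the helper


def compare_eager_and_scripted(model, x):
    # Illustrative helper (not from the commit): run the eager model and its
    # TorchScript counterpart from the same global RNG state, with autograd
    # disabled, so the two outputs can be compared directly.
    model.eval()
    scripted = torch.jit.script(model)
    scripted.eval()  # mirrors the `sm.eval()` line added in _check_jit_scriptable

    with torch.no_grad(), freeze_rng_state():
        eager_out = model(x)
    with torch.no_grad(), freeze_rng_state():
        scripted_out = scripted(x)

    torch.testing.assert_close(eager_out, scripted_out)
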
@@ -40,7 +40,7 @@ def _get_image(input_shape, real_image, device):
     - `fcos_resnet50_fpn`,
     - `maskrcnn_resnet50_fpn`,
     - `maskrcnn_resnet50_fpn_v2`,
-    in `test_classification_model` and `test_detection_mode`.
+    in `test_classification_model` and `test_detection_model`.
     To do so, a keyword argument `real_image` was added to the abovelisted models in `_model_params`
     """
     if real_image:
@@ -167,6 +167,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
         return imported
     sm = torch.jit.script(nn_module)
+    sm.eval()
     if eager_out is None:
         with torch.no_grad(), freeze_rng_state():
@@ -192,6 +193,7 @@ def _check_fx_compatible(model, inputs, eager_out=None):
     model_fx = torch.fx.symbolic_trace(model)
     if eager_out is None:
         eager_out = model(inputs)
-    fx_out = model_fx(inputs)
+    with torch.no_grad(), freeze_rng_state():
+        fx_out = model_fx(inputs)
     torch.testing.assert_close(eager_out, fx_out)
@@ -717,6 +719,7 @@ def test_segmentation_model(model_fn, dev):
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
-    out = model(x)
+    with torch.no_grad(), freeze_rng_state():
+        out = model(x)
     def check_out(out):
@@ -745,7 +748,7 @@ def test_segmentation_model(model_fn, dev):
     _check_fx_compatible(model, x, eager_out=out)
     if dev == "cuda":
-        with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
             out = model(x)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
@@ -782,6 +785,7 @@ def test_detection_model(model_fn, dev):
     model.eval().to(device=dev)
     x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
     model_input = [x]
-    out = model(model_input)
+    with torch.no_grad(), freeze_rng_state():
+        out = model(model_input)
     assert model_input[0] is x
@@ -843,7 +847,7 @@ def test_detection_model(model_fn, dev):
     _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
     if dev == "cuda":
-        with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
             out = model(model_input)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
......