Unverified commit bfb474b9, authored by Aidyn-A, committed by GitHub

Update test detection model (#6939)

* update tests

* revert replacing expect

* test_detection_model freeze_rng_state

* update tests
parent 4a310f26
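The change that repeats throughout the diff below is wrapping the test forward passes in torch.no_grad() and freeze_rng_state(): no autograd graph is built, and whatever random numbers the model consumes inside the block are rolled back on exit, so later code sees an unchanged RNG stream. A minimal sketch of what that pattern guarantees, assuming freeze_rng_state is the context manager from torch.testing._internal.common_utils (the test suite imports it through its own common_utils helpers):

# Sketch, not part of the diff: the wrapped block leaves the global RNG state
# untouched and does not track gradients.
import torch
from torch.testing._internal.common_utils import freeze_rng_state

torch.manual_seed(0)
before = torch.get_rng_state()

with torch.no_grad(), freeze_rng_state():
    _ = torch.rand(3)  # RNG consumed here is restored when the block exits

after = torch.get_rng_state()
assert torch.equal(before, after)  # identical RNG state before and after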
@@ -40,7 +40,7 @@ def _get_image(input_shape, real_image, device):
     - `fcos_resnet50_fpn`,
     - `maskrcnn_resnet50_fpn`,
     - `maskrcnn_resnet50_fpn_v2`,
-    in `test_classification_model` and `test_detection_mode`.
+    in `test_classification_model` and `test_detection_model`.
     To do so, a keyword argument `real_image` was added to the abovelisted models in `_model_params`
     """
     if real_image:
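For context on the docstring above: real_image appears to be a per-model switch in _model_params that makes _get_image return a fixed, realistic image rather than random noise for the listed detection models. A hypothetical sketch of how such a flag could be declared and consumed; the names mirror the diff, but the bodies are illustrative assumptions, not the repository's actual implementation:

# Illustrative only: how a per-model real_image flag might be declared and
# consumed; the real _model_params / _get_image in test_models.py may differ.
import torch

_model_params = {
    "retinanet_resnet50_fpn": {"real_image": True},  # assumed entry
    "maskrcnn_resnet50_fpn": {"real_image": True},   # assumed entry
}

def _get_image(input_shape, real_image, device):
    if real_image:
        # assumption: hand the detection heads a fixed, realistic input
        # (stubbed here as a constant tensor) instead of pure noise
        return torch.full(input_shape, 0.5).to(device=device)
    # random input, as used by the remaining model tests
    return torch.rand(input_shape).to(device=device)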
@@ -167,6 +167,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, eager_out=None):
         return imported

     sm = torch.jit.script(nn_module)
+    sm.eval()

     if eager_out is None:
         with torch.no_grad(), freeze_rng_state():
@@ -192,7 +193,8 @@ def _check_fx_compatible(model, inputs, eager_out=None):
     model_fx = torch.fx.symbolic_trace(model)
     if eager_out is None:
         eager_out = model(inputs)
-    fx_out = model_fx(inputs)
+    with torch.no_grad(), freeze_rng_state():
+        fx_out = model_fx(inputs)
     torch.testing.assert_close(eager_out, fx_out)
@@ -717,7 +719,8 @@ def test_segmentation_model(model_fn, dev):
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
-    out = model(x)
+    with torch.no_grad(), freeze_rng_state():
+        out = model(x)

     def check_out(out):
         prec = 0.01
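The "RNG always on CPU" comment in the hunk above is why the input is created with torch.rand(input_shape).to(device=dev) rather than sampled directly on the device: both the CPU and the CUDA variant of the test then consume the same CPU RNG stream and see bitwise-identical inputs. A small sketch of that idea (needs a CUDA device to actually run):

# Sketch, not part of the diff: sampling on the CPU and then moving the tensor
# keeps the CPU and CUDA test runs on the same random input.
import torch

torch.manual_seed(0)
x_cpu = torch.rand(1, 3, 32, 32)

torch.manual_seed(0)
x_cuda = torch.rand(1, 3, 32, 32).to(device="cuda")  # same draw, then moved

assert torch.equal(x_cpu, x_cuda.cpu())  # bitwise identical across devices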
@@ -745,7 +748,7 @@ def test_segmentation_model(model_fn, dev):
     _check_fx_compatible(model, x, eager_out=out)

     if dev == "cuda":
-        with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
             out = model(x)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
@@ -782,7 +785,8 @@ def test_detection_model(model_fn, dev):
     model.eval().to(device=dev)
     x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
     model_input = [x]
-    out = model(model_input)
+    with torch.no_grad(), freeze_rng_state():
+        out = model(model_input)
     assert model_input[0] is x

     def check_out(out):
@@ -843,7 +847,7 @@ def test_detection_model(model_fn, dev):
     _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)

     if dev == "cuda":
-        with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast(), torch.no_grad(), freeze_rng_state():
             out = model(model_input)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics: