"third_party/vscode:/vscode.git/clone" did not exist on "1d28bf8b71fc441e7dcf10b753d29cd1d2e1fe81"
Unverified commit 69ce4523, authored by Vasilis Vryniotis, committed by GitHub

Validate against expected files on videos (#6077)

* Validate against expected files on videos

* Plus tests for autocast
parent 3a2631ba
Three files in this commit are not shown (suppressed by a .gitattributes entry or an unsupported file encoding).
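For context, the expected-file check introduced in the diff below follows the usual expect-test pattern: the model output is compared against a reference tensor stored on disk, and that reference can be regenerated on demand. Seeding the RNG with set_rng_seed(0) is what makes the random input, and therefore the stored reference, reproducible. The snippet here is only a rough sketch of that idea under assumed names (the expect/ directory, the _expect.pt suffix, and the ACCEPT switch are placeholders); torchvision's actual helper is the _assert_expected function used in the diff.

import os
from pathlib import Path

import torch


def assert_expected_sketch(output: torch.Tensor, name: str, prec: float = 1e-5) -> None:
    # Hypothetical layout: expect/<name>_expect.pt next to the test file.
    expected_file = Path(__file__).parent / "expect" / f"{name}_expect.pt"
    if os.environ.get("ACCEPT") == "1":
        # Regenerate the reference instead of checking against it.
        expected_file.parent.mkdir(parents=True, exist_ok=True)
        torch.save(output, expected_file)
        return
    expected = torch.load(expected_file)
    # A loose tolerance (the test below uses prec=0.1) is deliberate: the models
    # are built without pretrained weights, so only coarse numerics are pinned down.
    torch.testing.assert_close(output, expected, atol=prec, rtol=prec)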
@@ -822,24 +822,36 @@ def test_detection_model_validation(model_fn):
 @pytest.mark.parametrize("model_fn", get_models_from_module(models.video))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_video_model(model_fn, dev):
+    set_rng_seed(0)
     # the default input shape is
     # bs * num_channels * clip_len * h *w
-    input_shape = (1, 3, 4, 112, 112)
+    defaults = {
+        "input_shape": (1, 3, 4, 112, 112),
+        "num_classes": 50,
+    }
     model_name = model_fn.__name__
+    kwargs = {**defaults, **_model_params.get(model_name, {})}
+    num_classes = kwargs.get("num_classes")
+    input_shape = kwargs.pop("input_shape")
     # test both basicblock and Bottleneck
-    model = model_fn(num_classes=50)
+    model = model_fn(**kwargs)
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
+    _assert_expected(out.cpu(), model_name, prec=0.1)
+    assert out.shape[-1] == num_classes
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
     _check_fx_compatible(model, x, eager_out=out)
-    assert out.shape[-1] == 50
+    assert out.shape[-1] == num_classes
     if dev == "cuda":
         with torch.cuda.amp.autocast():
             out = model(x)
-            assert out.shape[-1] == 50
+            # See autocast_flaky_numerics comment at top of file.
+            if model_name not in autocast_flaky_numerics:
+                _assert_expected(out.cpu(), model_name, prec=0.1)
+            assert out.shape[-1] == num_classes
     _check_input_backprop(model, x)
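One detail worth calling out in the new version: per-model overrides from _model_params are merged over the shared defaults, and input_shape is popped out of the merged dict because it is consumed by the test itself rather than by the model constructor (num_classes, by contrast, is read with .get() and still forwarded to model_fn). A standalone illustration of that merge, using a hypothetical override entry:

# Later dicts win in a {**a, **b} merge, so per-model entries override the defaults.
defaults = {"input_shape": (1, 3, 4, 112, 112), "num_classes": 50}
_model_params = {"some_video_model": {"input_shape": (1, 3, 8, 112, 112)}}  # hypothetical entry

model_name = "some_video_model"
kwargs = {**defaults, **_model_params.get(model_name, {})}

num_classes = kwargs.get("num_classes")  # 50, taken from the defaults and kept in kwargs
input_shape = kwargs.pop("input_shape")  # (1, 3, 8, 112, 112), removed so it is not passed to the model
print(input_shape, num_classes, kwargs)  # kwargs now holds only constructor arguments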