"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8957324363d8b239d82db4909fbf8c0875683e3d"
Unverified Commit 84469834 authored by Vasilis Vryniotis, committed by GitHub

Make test precision stricter for Classification (#6380)

* Make test precision stricter for Classification

* Update classification threshold.

* Update quantized classification threshold.
parent 96dbada4
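For context on the hunks below: `prec` is the tolerance used when comparing a model's output against its stored reference output. Here is a minimal sketch, assuming `_assert_expected` ultimately reduces to an absolute-tolerance tensor comparison (the real helper in torchvision's test suite also handles loading and regenerating the expected files):

```python
# Minimal sketch (assumption): _assert_expected compares the actual output
# against a stored reference tensor within an absolute tolerance `prec`.
# This only illustrates what tightening `prec` (e.g. 0.1 -> 1e-3) means:
# the regression test now catches much smaller deviations from the reference.
import torch

def assert_expected_sketch(actual: torch.Tensor, expected: torch.Tensor, prec: float) -> None:
    # rtol=0 so the check is purely the absolute threshold being tightened here
    torch.testing.assert_close(actual, expected, rtol=0.0, atol=prec)
```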
@@ -614,7 +614,7 @@ def test_classification_model(model_fn, dev):
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
-    _assert_expected(out.cpu(), model_name, prec=0.1)
+    _assert_expected(out.cpu(), model_name, prec=1e-3)
     assert out.shape[-1] == num_classes
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
     _check_fx_compatible(model, x, eager_out=out)
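The context lines above draw the random input on the CPU before moving it to the test device. A minimal sketch of that pattern (the shape is a placeholder, not the suite's actual `input_shape`):

```python
# Sketch of the RNG-on-CPU pattern from the hunk above: sampling on CPU with a
# fixed seed yields bitwise-identical inputs regardless of the target device,
# so the CPU and CUDA test runs compare the model on exactly the same `x`.
import torch

torch.manual_seed(0)
input_shape = (1, 3, 224, 224)  # placeholder shape, not the suite's value
x = torch.rand(input_shape).to(device="cuda" if torch.cuda.is_available() else "cpu")
```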
@@ -841,7 +841,7 @@ def test_video_model(model_fn, dev):
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
-    _assert_expected(out.cpu(), model_name, prec=0.1)
+    _assert_expected(out.cpu(), model_name, prec=1e-5)
     assert out.shape[-1] == num_classes
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
     _check_fx_compatible(model, x, eager_out=out)
@@ -884,7 +884,7 @@ def test_quantized_classification_model(model_fn):
     out = model(x)

     if model_name not in quantized_flaky_models:
-        _assert_expected(out, model_name + "_quantized", prec=0.1)
+        _assert_expected(out, model_name + "_quantized", prec=2e-2)
         assert out.shape[-1] == 5
         _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
         _check_fx_compatible(model, x, eager_out=out)
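Note that the quantized threshold stays looser (2e-2) than the float classification threshold (1e-3): quantized kernels round activations to a discrete grid, so their outputs cannot be held to the same tolerance. A toy round trip (illustrative values, not from the PR) shows the error scale involved:

```python
# Toy illustration (not from the PR): a quantize/dequantize round trip incurs
# rounding error of up to about half the quantization scale, which is why a
# quantized model's outputs need a looser comparison threshold.
import torch

t = torch.rand(1000)
q = torch.quantize_per_tensor(t, scale=0.01, zero_point=0, dtype=torch.quint8)
max_err = (q.dequantize() - t).abs().max().item()
assert max_err <= 0.005 + 1e-7  # bounded by scale / 2
print(f"max round-trip error: {max_err:.4f}")
```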