"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9a00cf194fcf994b2527cd927d691144f5e9c47b"
Unverified Commit 1962fddb authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add check for fx compatibility on segmentation and video models (#4131)

* Add check for fx compatibility on segmentation models.

* Add fx check on video models.
parent d16a1920
...@@ -466,6 +466,7 @@ def test_segmentation_model(model_name, dev): ...@@ -466,6 +466,7 @@ def test_segmentation_model(model_name, dev):
full_validation = check_out(out) full_validation = check_out(out)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
...@@ -616,6 +617,7 @@ def test_video_model(model_name, dev): ...@@ -616,6 +617,7 @@ def test_video_model(model_name, dev):
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
out = model(x) out = model(x)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
assert out.shape[-1] == 50 assert out.shape[-1] == 50
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment