Unverified Commit b4a075cb authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add test to check that classification models are FX-compatible (#3662)



* Add test to check that classification models are FX-compatible

* Replace torch.equal with torch.allclose

* remove skipling
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 68b128d5
...@@ -7,6 +7,7 @@ from collections import OrderedDict ...@@ -7,6 +7,7 @@ from collections import OrderedDict
import functools import functools
import operator import operator
import torch import torch
import torch.fx
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from torchvision import models from torchvision import models
...@@ -140,6 +141,13 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -140,6 +141,13 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
assert_export_import_module(sm, args) assert_export_import_module(sm, args)
def _check_fx_compatible(model, inputs):
model_fx = torch.fx.symbolic_trace(model)
out = model(inputs)
out_fx = model_fx(inputs)
torch.testing.assert_close(out, out_fx)
# If 'unwrapper' is provided it will be called with the script model outputs # If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the # before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode # model outputs are different between TorchScript / Eager mode
...@@ -408,6 +416,7 @@ def test_classification_model(model_name, dev): ...@@ -408,6 +416,7 @@ def test_classification_model(model_name, dev):
_assert_expected(out.cpu(), model_name, prec=0.1) _assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50 assert out.shape[-1] == 50
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
if dev == torch.device("cuda"): if dev == torch.device("cuda"):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment