Unverified Commit b4a075cb authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add test to check that classification models are FX-compatible (#3662)



* Add test to check that classification models are FX-compatible

* Replace torch.equal with torch.allclose

* remove skipling
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 68b128d5
......@@ -7,6 +7,7 @@ from collections import OrderedDict
import functools
import operator
import torch
import torch.fx
import torch.nn as nn
import torchvision
from torchvision import models
......@@ -140,6 +141,13 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
assert_export_import_module(sm, args)
def _check_fx_compatible(model, inputs):
model_fx = torch.fx.symbolic_trace(model)
out = model(inputs)
out_fx = model_fx(inputs)
torch.testing.assert_close(out, out_fx)
# If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode
......@@ -408,6 +416,7 @@ def test_classification_model(model_name, dev):
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
_check_fx_compatible(model, x)
if dev == torch.device("cuda"):
with torch.cuda.amp.autocast():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment