Unverified Commit 28d00482 authored by Michael Benayoun, committed by GitHub

Fx support for multiple model architectures (#17393)

* Support for Bart and LayoutLM, and partial support for XLNet

* Support for mbart

* A lot of new models supported

* Support for other models

* LayoutLM fix

* Use strings instead of classes
parent 04681c1d
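
Before the diff itself, a quick sense of what this enables: once a model sets fx_compatible = True, it can be traced with transformers.utils.fx.symbolic_trace, the same entry point the tests below exercise. A minimal sketch, assuming a torch and transformers install from around this commit; the tiny BartConfig values are illustrative only, and the accepted input names vary per architecture:

import torch
from transformers import BartConfig, BartModel
from transformers.utils.fx import symbolic_trace

# Tiny, randomly initialized config so the sketch runs quickly; return_dict=False
# yields tuple outputs, which is also what the test suite compares.
config = BartConfig(
    vocab_size=1000,
    d_model=64,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=128,
    decoder_ffn_dim=128,
    return_dict=False,
)
model = BartModel(config).eval()

# Trace against the concrete input names the GraphModule should accept.
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
traced = symbolic_trace(model, input_names)

# The traced graph should reproduce the eager outputs.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
kwargs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=input_ids,
    decoder_attention_mask=attention_mask,
)
print(torch.allclose(model(**kwargs)[0], traced(**kwargs)[0]))
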
@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
     """
 
     all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False
     test_head_masking = False
@@ -303,6 +303,7 @@ class CLIPTextModelTester:
 class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
 
     all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
+    fx_compatible = True
     test_pruning = False
     test_head_masking = False
@@ -388,6 +389,7 @@ class CLIPModelTester:
 @require_torch
 class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (CLIPModel,) if is_torch_available() else ()
+    fx_compatible = True
     test_head_masking = False
     test_pruning = False
     test_resize_embeddings = False
...
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else None
     )
+    fx_compatible = True
 
     def setUp(self):
         self.model_tester = LayoutLMModelTester(self)
...
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
     all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
     is_encoder_decoder = False
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_resize_position_embeddings = True
     test_pruning = False
     test_missing_keys = False
...
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
+    fx_compatible = True
     test_pruning = False
     test_missing_keys = False
...
@@ -17,6 +17,7 @@
 import copy
 import inspect
 import os
+import pickle
 import tempfile
 import unittest
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import cached_property
+from transformers.utils import cached_property, is_torch_fx_available
 from ...generation.test_generation_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -43,6 +44,9 @@ if is_torch_available():
     from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
     from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder
 
+if is_torch_fx_available():
+    from transformers.utils.fx import symbolic_trace
+
 
 def prepare_speech_to_text_inputs_dict(
     config,
all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else () all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
@@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.assertTrue(models_equal)
 
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
+            return
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+        configs_no_init.return_dict = False
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            model.to(torch_device)
+            model.eval()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+            try:
+                if model.config.is_encoder_decoder:
+                    model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+                    labels = inputs.get("labels", None)
+                    input_names = [
+                        "input_ids",
+                        "attention_mask",
+                        "decoder_input_ids",
+                        "decoder_attention_mask",
+                        "input_features",
+                    ]
+                    if labels is not None:
+                        input_names.append("labels")
+
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+                else:
+                    input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
+                    labels = inputs.get("labels", None)
+                    start_positions = inputs.get("start_positions", None)
+                    end_positions = inputs.get("end_positions", None)
+                    if labels is not None:
+                        input_names.append("labels")
+                    if start_positions is not None:
+                        input_names.append("start_positions")
+                    if end_positions is not None:
+                        input_names.append("end_positions")
+
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+
+            except RuntimeError as e:
+                self.fail(f"Couldn't trace module: {e}")
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
 
 @require_torch
 @require_torchaudio
...
@@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester:
 class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
+    fx_compatible = True
     test_pruning = False
 
     def setUp(
...
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Testing suite for the PyTorch Swin model. """
 
-import copy
 import inspect
 import os
 import pickle
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
 
 
 if is_torch_available():
@@ -45,14 +44,6 @@ if is_torch_fx_available():
     from transformers.utils.fx import symbolic_trace
 
 
-def _config_zero_init(config):
-    configs_no_init = copy.deepcopy(config)
-    for key in configs_no_init.__dict__.keys():
-        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
-            setattr(configs_no_init, key, 1e-10)
-    return configs_no_init
-
-
 class SwinModelTester:
     def __init__(
         self,
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                     input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
                     if labels is not None:
                         input_names.append("labels")
+
                     filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
 
                     model_output = model(**filtered_inputs)
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                         input_names.append("end_positions")
 
                     filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
-                    input_names = filtered_inputs.keys()
+                    input_names = list(filtered_inputs.keys())
 
                     model_output = model(**filtered_inputs)
...
@@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
+    fx_compatible = True
     all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
-    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = True
     test_model_parallel = True
...
@@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester:
 class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
+    fx_compatible = True
     test_pruning = False
 
     def setUp(self):
...
@@ -13,17 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import datetime
 import math
+import os
+import pickle
+import tempfile
 import unittest
 
 from transformers import XGLMConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import is_torch_fx_available
 
 from ...generation.test_generation_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_modeling_common import (
+    ModelTesterMixin,
+    _config_zero_init,
+    floats_tensor,
+    ids_tensor,
+    random_attention_mask,
+)
 
 
 if is_torch_available():
@@ -31,6 +40,9 @@ if is_torch_available():
     from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer
 
+if is_torch_fx_available():
+    from transformers.utils.fx import symbolic_trace
+
 
 class XGLMModelTester:
     def __init__(
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
+    fx_compatible = True
     test_missing_keys = False
     test_pruning = False
@@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)
 
+    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+        if not is_torch_fx_available() or not self.fx_compatible:
+            return
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+        configs_no_init.return_dict = False
+
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            model.to(torch_device)
+            model.eval()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+            try:
+                if model.config.is_encoder_decoder:
+                    model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
+                    labels = inputs.get("labels", None)
+                    input_names = [
+                        "input_ids",
+                        "attention_mask",
+                        "decoder_input_ids",
+                        "decoder_attention_mask",
+                        "input_features",
+                    ]
+                    if labels is not None:
+                        input_names.append("labels")
+
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+                else:
+                    input_names = [
+                        "input_ids",
+                        "attention_mask",
+                        "token_type_ids",
+                        "pixel_values",
+                        "bbox",
+                        "input_features",
+                    ]
+                    labels = inputs.get("labels", None)
+                    start_positions = inputs.get("start_positions", None)
+                    end_positions = inputs.get("end_positions", None)
+                    if labels is not None:
+                        input_names.append("labels")
+                    if start_positions is not None:
+                        input_names.append("start_positions")
+                    if end_positions is not None:
+                        input_names.append("end_positions")
+
+                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
+
+                    model_output = model(**filtered_inputs)
+
+                    traced_model = symbolic_trace(model, input_names)
+                    traced_output = traced_model(**filtered_inputs)
+
+            except RuntimeError as e:
+                self.fail(f"Couldn't trace module: {e}")
+
+            def flatten_output(output):
+                flatten = []
+                for x in output:
+                    if isinstance(x, (tuple, list)):
+                        flatten += flatten_output(x)
+                    elif not isinstance(x, torch.Tensor):
+                        continue
+                    else:
+                        flatten.append(x)
+                return flatten
+
+            model_output = flatten_output(model_output)
+            traced_output = flatten_output(traced_output)
+            num_outputs = len(model_output)
+
+            for i in range(num_outputs):
+                self.assertTrue(
+                    torch.allclose(model_output[i], traced_output[i]),
+                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
+                )
+
+            # Test that the model can be serialized and restored properly
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+                try:
+                    with open(pkl_file_name, "wb") as f:
+                        pickle.dump(traced_model, f)
+                    with open(pkl_file_name, "rb") as f:
+                        loaded = pickle.load(f)
+                except Exception as e:
+                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+                loaded_output = loaded(**filtered_inputs)
+                loaded_output = flatten_output(loaded_output)
+
+                for i in range(num_outputs):
+                    self.assertTrue(
+                        torch.allclose(model_output[i], loaded_output[i]),
+                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+                    )
+
     @slow
     def test_batch_generation(self):
         model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
...
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
+    fx_compatible = False
     test_pruning = False
 
     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
...
@@ -738,17 +738,32 @@ class ModelTesterMixin:
                 if model.config.is_encoder_decoder:
                     model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                     labels = inputs.get("labels", None)
-                    input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+                    input_names = [
+                        "input_ids",
+                        "attention_mask",
+                        "decoder_input_ids",
+                        "decoder_attention_mask",
+                        "input_features",
+                    ]
                     if labels is not None:
                         input_names.append("labels")
+
                     filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+                    input_names = list(filtered_inputs.keys())
 
                     model_output = model(**filtered_inputs)
 
                     traced_model = symbolic_trace(model, input_names)
                     traced_output = traced_model(**filtered_inputs)
                 else:
-                    input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+                    input_names = [
+                        "input_ids",
+                        "attention_mask",
+                        "token_type_ids",
+                        "pixel_values",
+                        "bbox",
+                        "input_features",
+                    ]
                     labels = inputs.get("labels", None)
                     start_positions = inputs.get("start_positions", None)
@@ -761,7 +776,7 @@ class ModelTesterMixin:
                         input_names.append("end_positions")
 
                     filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
-                    input_names = filtered_inputs.keys()
+                    input_names = list(filtered_inputs.keys())
 
                     model_output = model(**filtered_inputs)
...
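
The serialization check at the end of the new test method also doubles as a recipe: a traced GraphModule can be pickled to disk and restored, and the restored module should produce identical outputs. A minimal round-trip sketch, reusing the hypothetical traced, kwargs, and torch from the sketch near the top:

import os
import pickle
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = os.path.join(tmp_dir, "model.pkl")
    with open(pkl_path, "wb") as f:
        pickle.dump(traced, f)  # `traced` from the earlier sketch
    with open(pkl_path, "rb") as f:
        restored = pickle.load(f)

# The restored graph should match the original traced outputs exactly.
print(torch.allclose(traced(**kwargs)[0], restored(**kwargs)[0]))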