Unverified commit 28d00482, authored by Michael Benayoun, committed by GitHub

Fx support for multiple model architectures (#17393)

* Support for Bart and LayoutLM, and partial support for XLNet

* Support for mbart

* A lot of new models supported

* Support for other models

* LayoutLM fix

* Use strings instead of classes
parent 04681c1d
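
For context: the diff below extends the shared `_create_and_check_torch_fx_tracing` test helper and flips `fx_compatible = True` on each newly supported architecture. A minimal sketch of what that flag asserts, assuming the `transformers.utils.fx` API these tests exercise (the tiny MBart config values are illustrative, not from the diff):

import torch
from transformers import MBartConfig, MBartForConditionalGeneration
from transformers.utils.fx import symbolic_trace

# Tiny, randomly initialized model; return_dict=False and use_cache=False
# mirror what the tests below set up before tracing.
config = MBartConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    return_dict=False,
    use_cache=False,
)
model = MBartForConditionalGeneration(config).eval()

inputs = {
    "input_ids": torch.randint(0, 100, (1, 8)),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "decoder_input_ids": torch.randint(0, 100, (1, 8)),
    "decoder_attention_mask": torch.ones(1, 8, dtype=torch.long),
}
traced = symbolic_trace(model, input_names=list(inputs))  # -> torch.fx.GraphModule

# The traced graph should reproduce the eager outputs exactly.
with torch.no_grad():
    assert torch.allclose(model(**inputs)[0], traced(**inputs)[0])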
@@ -213,6 +213,7 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
    all_model_classes = (BlenderbotSmallModel, BlenderbotSmallForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -152,7 +152,7 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (CLIPVisionModel,) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
@@ -303,6 +303,7 @@ class CLIPTextModelTester:
class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPTextModel,) if is_torch_available() else ()
    fx_compatible = True
    test_pruning = False
    test_head_masking = False
@@ -388,6 +389,7 @@ class CLIPModelTester:
@require_torch
class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPModel,) if is_torch_available() else ()
    fx_compatible = True
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
......
@@ -215,6 +215,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else None
    )
    fx_compatible = True

    def setUp(self):
        self.model_tester = LayoutLMModelTester(self)
......
@@ -231,6 +231,7 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
    )
    all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -230,6 +230,7 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
    all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
    all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -224,6 +224,7 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
    )
    all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -178,6 +178,7 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
    is_encoder_decoder = False
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
    all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_resize_position_embeddings = True
    test_pruning = False
    test_missing_keys = False
......
@@ -219,6 +219,7 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
    )
    all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
......
@@ -17,6 +17,7 @@
import copy
import inspect
import os
import pickle
import tempfile
import unittest
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
    slow,
    torch_device,
)
-from transformers.utils import cached_property
+from transformers.utils import cached_property, is_torch_fx_available

from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -43,6 +44,9 @@ if is_torch_available():
    from transformers import Speech2TextForConditionalGeneration, Speech2TextModel, Speech2TextProcessor
    from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder


if is_torch_fx_available():
    from transformers.utils.fx import symbolic_trace


def prepare_speech_to_text_inputs_dict(
    config,
@@ -271,6 +275,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
    all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False
@@ -715,6 +720,105 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
        self.assertTrue(models_equal)

    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
        if not is_torch_fx_available() or not self.fx_compatible:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

            try:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
                    labels = inputs.get("labels", None)
                    input_names = [
                        "input_ids",
                        "attention_mask",
                        "decoder_input_ids",
                        "decoder_attention_mask",
                        "input_features",
                    ]
                    if labels is not None:
                        input_names.append("labels")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                else:
                    input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values", "input_features"]
                    labels = inputs.get("labels", None)
                    start_positions = inputs.get("start_positions", None)
                    end_positions = inputs.get("end_positions", None)
                    if labels is not None:
                        input_names.append("labels")
                    if start_positions is not None:
                        input_names.append("start_positions")
                    if end_positions is not None:
                        input_names.append("end_positions")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)

            except RuntimeError as e:
                self.fail(f"Couldn't trace module: {e}")

            def flatten_output(output):
                flatten = []
                for x in output:
                    if isinstance(x, (tuple, list)):
                        flatten += flatten_output(x)
                    elif not isinstance(x, torch.Tensor):
                        continue
                    else:
                        flatten.append(x)
                return flatten

            model_output = flatten_output(model_output)
            traced_output = flatten_output(traced_output)
            num_outputs = len(model_output)

            for i in range(num_outputs):
                self.assertTrue(
                    torch.allclose(model_output[i], traced_output[i]),
                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
                )

            # Test that the model can be serialized and restored properly
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
                try:
                    with open(pkl_file_name, "wb") as f:
                        pickle.dump(traced_model, f)
                    with open(pkl_file_name, "rb") as f:
                        loaded = pickle.load(f)
                except Exception as e:
                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")

                loaded_output = loaded(**filtered_inputs)
                loaded_output = flatten_output(loaded_output)

                for i in range(num_outputs):
                    self.assertTrue(
                        torch.allclose(model_output[i], loaded_output[i]),
                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
                    )
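
Worth noting in the candidate input names above: `input_features` is what lets this speech model be traced at all, since Speech2Text consumes log-mel features rather than `input_ids`. A hedged, self-contained sketch (the tiny config values are illustrative, not from this test):

import torch
from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration
from transformers.utils.fx import symbolic_trace

config = Speech2TextConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    return_dict=False,
    use_cache=False,
)
model = Speech2TextForConditionalGeneration(config).eval()

# (batch, frames, mel bins): the model ingests audio features, not token ids.
inputs = {
    "input_features": torch.zeros(1, 20, config.input_feat_per_channel),
    "decoder_input_ids": torch.zeros(1, 4, dtype=torch.long),
}
traced = symbolic_trace(model, input_names=list(inputs))
with torch.no_grad():
    assert torch.allclose(model(**inputs)[0], traced(**inputs)[0])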
@require_torch
@require_torchaudio
......
@@ -179,6 +179,7 @@ class Speech2Text2StandaloneDecoderModelTester:
class Speech2Text2StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (Speech2Text2Decoder, Speech2Text2ForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (Speech2Text2ForCausalLM,) if is_torch_available() else ()
    fx_compatible = True
    test_pruning = False

    def setUp(
......
@@ -14,7 +14,6 @@
# limitations under the License.
""" Testing suite for the PyTorch Swin model. """
import copy
import inspect
import os
import pickle
@@ -26,7 +25,7 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
from transformers.utils import cached_property, is_torch_available, is_torch_fx_available, is_vision_available

from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@@ -45,14 +44,6 @@ if is_torch_fx_available():
    from transformers.utils.fx import symbolic_trace


-def _config_zero_init(config):
-    configs_no_init = copy.deepcopy(config)
-    for key in configs_no_init.__dict__.keys():
-        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
-            setattr(configs_no_init, key, 1e-10)
-    return configs_no_init
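
This Swin-local copy is deleted in favor of the identical `_config_zero_init` now imported from `test_modeling_common`; behavior is unchanged. For reference, an illustrative usage sketch (`SwinConfig` is an arbitrary choice):

from transformers import SwinConfig

from tests.test_modeling_common import _config_zero_init  # path assumes running from the repo root

config = _config_zero_init(SwinConfig())
print(config.initializer_range)  # 1e-10: near-zero init keeps traced-vs-eager comparisons free of NaN/inf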
class SwinModelTester:
    def __init__(
        self,
@@ -407,7 +398,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
@@ -427,7 +420,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys()
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
......
@@ -509,8 +509,8 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
+   fx_compatible = True
    all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
-   fx_compatible = True
    test_pruning = False
    test_resize_embeddings = True
    test_model_parallel = True
......
@@ -161,6 +161,7 @@ class TrOCRStandaloneDecoderModelTester:
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
    fx_compatible = True
    test_pruning = False

    def setUp(self):
......
@@ -13,17 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import math
import os
import pickle
import tempfile
import unittest
from transformers import XGLMConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+from ...test_modeling_common import (
+    ModelTesterMixin,
+    _config_zero_init,
+    floats_tensor,
+    ids_tensor,
+    random_attention_mask,
+)
if is_torch_available():
@@ -31,6 +40,9 @@ if is_torch_available():
    from transformers import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMTokenizer


if is_torch_fx_available():
    from transformers.utils.fx import symbolic_trace


class XGLMModelTester:
    def __init__(
@@ -299,6 +311,7 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (XGLMModel, XGLMForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (XGLMForCausalLM,) if is_torch_available() else ()
    fx_compatible = True
    test_missing_keys = False
    test_pruning = False
@@ -337,6 +350,112 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xglm_weight_initialization(*config_and_inputs)

    def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
        if not is_torch_fx_available() or not self.fx_compatible:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)

            try:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
                    labels = inputs.get("labels", None)
                    input_names = [
                        "input_ids",
                        "attention_mask",
                        "decoder_input_ids",
                        "decoder_attention_mask",
                        "input_features",
                    ]
                    if labels is not None:
                        input_names.append("labels")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                else:
                    input_names = [
                        "input_ids",
                        "attention_mask",
                        "token_type_ids",
                        "pixel_values",
                        "bbox",
                        "input_features",
                    ]
                    labels = inputs.get("labels", None)
                    start_positions = inputs.get("start_positions", None)
                    end_positions = inputs.get("end_positions", None)
                    if labels is not None:
                        input_names.append("labels")
                    if start_positions is not None:
                        input_names.append("start_positions")
                    if end_positions is not None:
                        input_names.append("end_positions")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)

            except RuntimeError as e:
                self.fail(f"Couldn't trace module: {e}")

            def flatten_output(output):
                flatten = []
                for x in output:
                    if isinstance(x, (tuple, list)):
                        flatten += flatten_output(x)
                    elif not isinstance(x, torch.Tensor):
                        continue
                    else:
                        flatten.append(x)
                return flatten

            model_output = flatten_output(model_output)
            traced_output = flatten_output(traced_output)
            num_outputs = len(model_output)

            for i in range(num_outputs):
                self.assertTrue(
                    torch.allclose(model_output[i], traced_output[i]),
                    f"traced {i}th output doesn't match model {i}th output for {model_class}",
                )

            # Test that the model can be serialized and restored properly
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
                try:
                    with open(pkl_file_name, "wb") as f:
                        pickle.dump(traced_model, f)
                    with open(pkl_file_name, "rb") as f:
                        loaded = pickle.load(f)
                except Exception as e:
                    self.fail(f"Couldn't serialize / deserialize the traced model: {e}")

                loaded_output = loaded(**filtered_inputs)
                loaded_output = flatten_output(loaded_output)

                for i in range(num_outputs):
                    self.assertTrue(
                        torch.allclose(model_output[i], loaded_output[i]),
                        f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
                    )
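
The serialization half of the test above leans on the traced `torch.fx.GraphModule` being picklable. Stripped of the test harness, the round-trip is just (sketch; `traced_model` as produced by `symbolic_trace` above):

import pickle

# Serialize the traced graph together with its parameters...
payload = pickle.dumps(traced_model)
# ...and restore it without retracing; the result is directly callable.
restored = pickle.loads(payload)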
    @slow
    def test_batch_generation(self):
        model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
......
@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
    all_generative_model_classes = (
        (XLNetLMHeadModel,) if is_torch_available() else ()
    )  # TODO (PVP): Check other models whether language generation is also applicable
    fx_compatible = False
    test_pruning = False

    # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
......
@@ -738,17 +738,32 @@ class ModelTesterMixin:
                if model.config.is_encoder_decoder:
                    model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
                    labels = inputs.get("labels", None)
-                   input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+                   input_names = [
+                       "input_ids",
+                       "attention_mask",
+                       "decoder_input_ids",
+                       "decoder_attention_mask",
+                       "input_features",
+                   ]
                    if labels is not None:
                        input_names.append("labels")

                    filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
                    input_names = list(filtered_inputs.keys())

                    model_output = model(**filtered_inputs)

                    traced_model = symbolic_trace(model, input_names)
                    traced_output = traced_model(**filtered_inputs)
                else:
-                   input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+                   input_names = [
+                       "input_ids",
+                       "attention_mask",
+                       "token_type_ids",
+                       "pixel_values",
+                       "bbox",
+                       "input_features",
+                   ]
                    labels = inputs.get("labels", None)
                    start_positions = inputs.get("start_positions", None)
@@ -761,7 +776,7 @@ class ModelTesterMixin:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = filtered_inputs.keys()
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
......
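
The two new candidate names in this shared helper are what unlock several of the architectures in this PR: `bbox` for LayoutLM and `input_features` for the speech models. A hedged sketch of the LayoutLM case (tiny config values are illustrative, not from the diff):

import torch
from transformers import LayoutLMConfig, LayoutLMModel
from transformers.utils.fx import symbolic_trace

config = LayoutLMConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
    return_dict=False,
)
model = LayoutLMModel(config).eval()

inputs = {
    "input_ids": torch.randint(0, 100, (1, 8)),
    "bbox": torch.zeros(1, 8, 4, dtype=torch.long),  # (x0, y0, x1, y1) per token
}
traced = symbolic_trace(model, input_names=list(inputs))
with torch.no_grad():
    assert torch.allclose(model(**inputs)[0], traced(**inputs)[0])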