"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f83d9c8da7e232b9116ced80a23a98d1f1d64bec"
Unverified Commit c603c80f authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

FX support for ConvNext, Wav2Vec2 and ResNet (#19053)

* Support for ConvNext

* Support for Wav2Vec2

* Support for Resnet

* Fix small issue in test_modeling_convnext
parent c8e40d6f
...@@ -960,7 +960,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): ...@@ -960,7 +960,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
# take argmax in non-differentiable way # take argmax in non-differentiable way
# comptute hard codevector distribution (one hot) # comptute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1) codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0 -1, codevector_idx.view(-1, 1), 1.0
) )
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
......
...@@ -1023,7 +1023,7 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module): ...@@ -1023,7 +1023,7 @@ class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
# take argmax in non-differentiable way # take argmax in non-differentiable way
# comptute hard codevector distribution (one hot) # comptute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(dim=-1) codevector_idx = hidden_states.argmax(dim=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
-1, codevector_idx.view(-1, 1), 1.0 -1, codevector_idx.view(-1, 1), 1.0
) )
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
......
...@@ -104,6 +104,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ ...@@ -104,6 +104,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"blenderbot-small", "blenderbot-small",
"bloom", "bloom",
"clip", "clip",
"convnext",
"deberta", "deberta",
"deberta-v2", "deberta-v2",
"distilbert", "distilbert",
...@@ -125,6 +126,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ ...@@ -125,6 +126,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"opt", "opt",
"pegasus", "pegasus",
"plbart", "plbart",
"resnet",
"roberta", "roberta",
"speech_to_text", "speech_to_text",
"speech_to_text_2", "speech_to_text_2",
...@@ -133,6 +135,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [ ...@@ -133,6 +135,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"trocr", "trocr",
"vit", "vit",
"xglm", "xglm",
"wav2vec2",
# "xlnet", # "xlnet",
] ]
...@@ -743,7 +746,7 @@ class HFTracer(Tracer): ...@@ -743,7 +746,7 @@ class HFTracer(Tracer):
elif hasattr(model.config, "encoder"): elif hasattr(model.config, "encoder"):
image_size = model.config.encoder.image_size image_size = model.config.encoder.image_size
else: else:
raise AttributeError('Could not find the "image_size" field in the model config') image_size = (_generate_random_int(), _generate_random_int())
# If no num_channels is in the config, use some arbitrary value. # If no num_channels is in the config, use some arbitrary value.
num_channels = getattr(model.config, "num_channels", 3) num_channels = getattr(model.config, "num_channels", 3)
......
...@@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -137,6 +137,7 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
else () else ()
) )
fx_compatible = True
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
......
...@@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -126,6 +126,7 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else () all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
fx_compatible = True
test_pruning = False test_pruning = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
""" Testing suite for the PyTorch Wav2Vec2 model. """ """ Testing suite for the PyTorch Wav2Vec2 model. """
import math import math
import os
import pickle
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -32,6 +35,7 @@ from transformers.testing_utils import ( ...@@ -32,6 +35,7 @@ from transformers.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ( from ...test_modeling_common import (
...@@ -72,6 +76,10 @@ if is_pyctcdecode_available(): ...@@ -72,6 +76,10 @@ if is_pyctcdecode_available():
from transformers import Wav2Vec2ProcessorWithLM from transformers import Wav2Vec2ProcessorWithLM
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
class Wav2Vec2ModelTester: class Wav2Vec2ModelTester:
def __init__( def __init__(
self, self,
...@@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -411,6 +419,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available() if is_torch_available()
else () else ()
) )
fx_compatible = True
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
...@@ -633,6 +642,106 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -633,6 +642,106 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model) self.assertIsNotNone(model)
# Wav2Vec2 cannot be torchscripted because of group norm.
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
input_names = [
"attention_mask",
"bbox",
"input_features",
"input_ids",
"input_values",
"pixel_values",
"token_type_ids",
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
if (
isinstance(model, Wav2Vec2ForSequenceClassification)
and not hasattr(model.config, "problem_type")
or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except Exception as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
for x in output:
if isinstance(x, (tuple, list)):
flatten += flatten_output(x)
elif not isinstance(x, torch.Tensor):
continue
else:
flatten.append(x)
return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
@require_torch @require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment