"vscode:/vscode.git/clone" did not exist on "7a1c68a8454c25c55f3f8978c182ea90e3412f5c"
Unverified commit 5c8f6010, authored by Michael Benayoun and committed by GitHub

Fx support for Deberta-v[1-2], Hubert and LXMERT (#17539)

* Support for deberta and deberta-v2

* Support for LXMert

* Support for Hubert

* Fix for pt1.11

* Trigger CI
parent 3cab9027
@@ -104,9 +104,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))
output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -129,7 +129,7 @@ class XSoftmax(torch.autograd.Function):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -152,7 +152,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -564,7 +564,7 @@ class DisentangledSelfAttention(nn.Module):
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
@@ -652,7 +652,7 @@ class DisentangledSelfAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
......
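The recurring rewrite in the DeBERTa modeling code above, from x.view(*new_x_shape) to x.view(new_x_shape), is what makes these modules traceable with torch.fx: during symbolic tracing x.size() yields a Proxy, which cannot be star-unpacked, while Tensor.view accepts the shape sequence itself. A minimal stand-alone sketch of the difference (the Reshape module below is illustrative, not part of the PR):

import torch
from torch import fx, nn

class Reshape(nn.Module):
    def forward(self, x):
        new_shape = x.size()[:-1] + (4, -1)
        # x.view(*new_shape) fails under fx tracing because a Proxy cannot be iterated;
        # passing the (proxied) shape sequence directly keeps the graph traceable.
        return x.view(new_shape)

traced = fx.symbolic_trace(Reshape())
print(traced(torch.randn(2, 3, 8)).shape)  # torch.Size([2, 3, 4, 2])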
@@ -107,9 +107,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))
output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -132,7 +132,7 @@ class XSoftmax(torch.autograd.Function):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -157,7 +157,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -638,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
@@ -719,7 +719,7 @@ class DisentangledSelfAttention(nn.Module):
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
......
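The same pattern repeats in the DeBERTa-v2 variant of this code, together with the change from float("-inf") to torch.tensor(torch.finfo(input.dtype).min) and from mask.bool() to mask.to(torch.bool). The commit message does not spell out the motivation, but a dtype-aware minimum is a finite, traceable stand-in for -inf (note that "tensor" is also added to the tracer's patched methods further down); numerically the two fills produce the same masked softmax, as this small stand-alone check suggests:

import torch

scores = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 1]])
rmask = ~(mask.to(torch.bool))

with_inf = torch.softmax(scores.masked_fill(rmask, float("-inf")), dim=-1)
with_min = torch.softmax(scores.masked_fill(rmask, torch.finfo(scores.dtype).min), dim=-1)
# Masked positions are zeroed in both variants and the unmasked probabilities agree.
print(torch.allclose(with_inf.masked_fill(rmask, 0), with_min.masked_fill(rmask, 0)))  # True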
@@ -336,7 +336,7 @@ class LxmertAttention(nn.Module):
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
@@ -365,7 +365,7 @@ class LxmertAttention(nn.Module):
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
@@ -1253,7 +1253,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
visual_prediction_scores = visual_prediction_scores_dict[key]
visual_loss = visual_loss_fct(
visual_prediction_scores.view(-1, output_dim),
label.view(*label_shape),
label.view(label_shape),
)
if visual_loss.dim() > 1: # Regression Losses
visual_loss = visual_loss.mean(1)
......
@@ -261,7 +261,7 @@ def get_mask(input, local_context):
mask = local_context.mask if local_context.reuse_mask else None
if dropout > 0 and mask is None:
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
if isinstance(local_context, DropoutContext):
if local_context.mask is None:
@@ -532,9 +532,9 @@ class XSoftmax(torch.autograd.Function):
@staticmethod
def forward(self, input, mask, dim):
self.dim = dim
rmask = ~(mask.bool())
rmask = ~(mask.to(torch.bool))
output = input.masked_fill(rmask, float("-inf"))
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
output = torch.softmax(output, self.dim)
output.masked_fill_(rmask, 0)
self.save_for_backward(output)
@@ -557,7 +557,7 @@ class XSoftmax(torch.autograd.Function):
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
to_i=sym_help.cast_pytorch_to_onnx["Byte"],
)
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))
output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
@@ -711,7 +711,7 @@ class DisentangledSelfAttention(nn.Module):
def transpose_for_scores(self, x, attention_heads):
new_x_shape = x.size()[:-1] + (attention_heads, -1)
x = x.view(*new_x_shape)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
@@ -792,7 +792,7 @@ class DisentangledSelfAttention(nn.Module):
.contiguous()
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
......
@@ -32,7 +32,9 @@ from torch.fx.proxy import ParameterProxy
from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -72,6 +74,8 @@ def _generate_supported_model_class_names(
"token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
"masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
"image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
"ctc": MODEL_FOR_CTC_MAPPING_NAMES,
"audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
}
if supported_tasks is None:
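The two new task keys let _generate_supported_model_class_names resolve the audio head classes from the auto-mappings, which are plain ordered dicts from model type to class-name string. A quick, hedged lookup example using the mappings imported above:

from transformers.models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
)

print(MODEL_FOR_CTC_MAPPING_NAMES["hubert"])                   # the CTC head class, HubertForCTC
print(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES["hubert"])  # Hubert's sequence-classification head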
@@ -95,12 +99,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"blenderbot",
"blenderbot-small",
"clip",
"deberta",
"deberta-v2",
"distilbert",
"electra",
"gpt2",
"gpt_neo",
"gptj",
"hubert",
"layoutlm",
"lxmert",
"m2m_100",
"marian",
"mbart",
@@ -118,8 +126,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"trocr",
"vit",
"xglm",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
# "xlnet",
# TODO: add support for them as it should be quite easy to do so (small blocking issues).
]
_REGULAR_SUPPORTED_MODELS = []
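With the new entries in this list, the added architectures can be traced through transformers' FX helper just like the previously supported ones. A hedged usage sketch (the checkpoint name and input names below are illustrative, not taken from the PR):

from transformers import AutoModel
from transformers.utils.fx import symbolic_trace

model = AutoModel.from_pretrained("microsoft/deberta-base")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
# traced is a torch.fx.GraphModule; its recorded graph can be inspected or transformed.
print(traced.graph)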
@@ -155,6 +163,10 @@ def torch_nn_layernorm(self, input):
return input
def torch_nn_groupnorm(self, input):
return input
def torch_nn_linear(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
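These torch_nn_* helpers are shape-only stand-ins used while HFTracer runs the model on the "meta" device: a shape-preserving op such as GroupNorm can simply hand back its input, whereas Linear (shown as context above) has to fabricate a tensor whose last dimension is out_features. A small sketch of that shape bookkeeping outside the tracer (not the tracer's actual dispatch machinery):

import torch

linear = torch.nn.Linear(8, 16)
x = torch.empty(4, 8, device="meta")  # carries shape and dtype only, no storage

# What the Linear override computes: the input shape with its last dim replaced by
# out_features; no matmul is ever executed on the meta device.
y = torch.empty(x.shape[:-1] + (linear.out_features,), device="meta")
print(y.shape)  # torch.Size([4, 16])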
@@ -372,6 +384,27 @@ def torch_nn_conv2d(self, input):
return torch.empty(shape, device="meta")
def torch_squeeze(input, dim=None):
shape = list(input.shape)
if dim is not None:
if dim < 0:
dim = input.dim() + dim
if shape[dim] == 1:
shape.pop(dim)
else:
new_shape = []
for dim_value in shape:
if dim_value == 1:
continue
new_shape.append(dim_value)
shape = new_shape
return torch.empty(shape, device="meta")
def torch_tensor_squeeze(self, dim=None):
return torch_squeeze(self, dim)
def torch_unsqueeze(input, dim):
shape = list(input.shape)
if dim < 0:
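The new torch_squeeze override reproduces torch.squeeze's shape rules so the tracer can predict them on meta tensors: with no dim, every size-1 dimension is dropped; with a dim, only that dimension is dropped and only if it equals 1. A quick stand-alone check of those rules:

import torch

x = torch.empty(2, 1, 3, 1, device="meta")
print(torch.squeeze(x).shape)         # torch.Size([2, 3]): all size-1 dims removed
print(torch.squeeze(x, dim=1).shape)  # torch.Size([2, 3, 1]): only dim 1 removed
print(torch.squeeze(x, dim=0).shape)  # torch.Size([2, 1, 3, 1]): unchanged, dim 0 is not 1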
@@ -446,6 +479,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.nn.Embedding: torch_nn_embedding,
torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm,
torch.nn.GroupNorm: torch_nn_groupnorm,
torch.nn.Linear: torch_nn_linear,
torch.relu: torch_relu,
torch.nn.functional.relu: torch_nn_functional_relu,
@@ -469,6 +503,8 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
torch.Tensor.index_select: torch_tensor_index_select,
torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d,
torch.squeeze: torch_squeeze,
torch.Tensor.squeeze: torch_tensor_squeeze,
torch.unsqueeze: torch_unsqueeze,
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
@@ -605,7 +641,7 @@ class HFTracer(Tracer):
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty"]
_TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"]
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
@@ -704,8 +740,31 @@ class HFTracer(Tracer):
inputs_dict[input_name] = torch.zeros(
*shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
)
elif "visual_feats" in input_name:
inputs_dict[input_name] = torch.zeros(
shape
+ [
model.config.visual_feat_dim,
],
dtype=torch.float,
device=device,
)
elif "visual_pos" in input_name:
inputs_dict[input_name] = torch.zeros(
shape
+ [
model.config.visual_pos_dim,
],
dtype=torch.float,
device=device,
)
elif "inputs" in input_name:
inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
elif "input_values" in input_name:
batch_size, _ = shape
# Generating big sequence length for audio inputs.
seq_length = _generate_random_int(low=10000, high=20000)
inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
elif "mask" in input_name or "ids" in input_name:
inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
else:
......
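The branches added above generate tracing dummies for LXMERT's visual stream and for raw-audio models: visual_feats and visual_pos get their trailing dimension from the config, and input_values gets a long random sequence so convolutional feature extractors see realistic lengths. A hedged sketch of the resulting shapes (the feature dimensions are assumed LxmertConfig defaults and the audio length is a stand-in for the randomly drawn 10000-20000 value):

import torch

batch_size, text_len, num_visual = 2, 9, 5
visual_feat_dim, visual_pos_dim = 2048, 4   # assumed LxmertConfig defaults, for illustration only
audio_len = 16000                           # stand-in for the randomly generated sequence length

dummy_inputs = {
    "input_ids":    torch.zeros(batch_size, text_len, dtype=torch.long),
    "visual_feats": torch.zeros(batch_size, num_visual, visual_feat_dim),
    "visual_pos":   torch.zeros(batch_size, num_visual, visual_pos_dim),
    "input_values": torch.zeros(batch_size, audio_len),  # Hubert-style raw waveform
}
print({name: tuple(t.shape) for name, t in dummy_inputs.items()})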
@@ -222,6 +222,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
......
@@ -241,6 +241,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
else ()
)
fx_compatible = True
test_torchscript = False
test_pruning = False
test_head_masking = False
......
@@ -16,12 +16,16 @@
import math
import os
import pickle
import tempfile
import unittest
import pytest
from transformers import HubertConfig, is_torch_available
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
from transformers.utils import is_torch_fx_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
@@ -45,6 +49,9 @@ if is_torch_available():
)
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
if is_torch_fx_available():
from transformers.utils.fx import symbolic_trace
class HubertModelTester:
def __init__(
@@ -299,6 +306,7 @@ class HubertModelTester:
@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
test_headmasking = False
@@ -417,6 +425,117 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# Hubert cannot be TorchScripted because of torch.nn.utils.weight_norm
def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
if not is_torch_fx_available() or not self.fx_compatible:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"attention_mask",
"decoder_attention_mask",
"decoder_input_ids",
"input_features",
"input_ids",
"input_values",
]
if labels is not None:
input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
else:
input_names = [
"attention_mask",
"bbox",
"input_features",
"input_ids",
"input_values",
"pixel_values",
"token_type_ids",
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None)
if labels is not None:
input_names.append("labels")
if start_positions is not None:
input_names.append("start_positions")
if end_positions is not None:
input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except Exception as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
flatten = []
for x in output:
if isinstance(x, (tuple, list)):
flatten += flatten_output(x)
elif not isinstance(x, torch.Tensor):
continue
else:
flatten.append(x)
return flatten
model_output = flatten_output(model_output)
traced_output = flatten_output(traced_output)
num_outputs = len(model_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], traced_output[i]),
f"traced {i}th output doesn't match model {i}th output for {model_class}",
)
# Test that the model can be serialized and restored properly
with tempfile.TemporaryDirectory() as tmp_dir_name:
pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
try:
with open(pkl_file_name, "wb") as f:
pickle.dump(traced_model, f)
with open(pkl_file_name, "rb") as f:
loaded = pickle.load(f)
except Exception as e:
self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
loaded_output = loaded(**filtered_inputs)
loaded_output = flatten_output(loaded_output)
for i in range(num_outputs):
self.assertTrue(
torch.allclose(model_output[i], loaded_output[i]),
f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
)
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
......
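Hubert becomes fx_compatible here but still cannot be TorchScripted (the weight_norm issue noted in the test comment above), so the common FX check is copied into the test class rather than inherited. End to end, the feature this test exercises looks roughly like the following sketch; the checkpoint name is illustrative and, mirroring the test, tuple outputs are requested via return_dict=False:

import torch
from transformers import HubertModel
from transformers.utils.fx import symbolic_trace

model = HubertModel.from_pretrained("facebook/hubert-base-ls960", return_dict=False).eval()
traced = symbolic_trace(model, input_names=["input_values"])

waveform = torch.randn(1, 16000)  # roughly one second of 16 kHz audio
with torch.no_grad():
    last_hidden_state = traced(input_values=waveform)[0]
print(last_hidden_state.shape)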
@@ -535,6 +535,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
fx_compatible = True
test_head_masking = False
test_pruning = False
test_torchscript = False
......
@@ -740,11 +740,12 @@ class ModelTesterMixin:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
labels = inputs.get("labels", None)
input_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"decoder_input_ids",
"input_features",
"input_ids",
"input_values",
]
if labels is not None:
input_names.append("labels")
@@ -758,12 +759,15 @@ class ModelTesterMixin:
traced_output = traced_model(**filtered_inputs)
else:
input_names = [
"input_ids",
"attention_mask",
"token_type_ids",
"pixel_values",
"bbox",
"input_features",
"input_ids",
"input_values",
"pixel_values",
"token_type_ids",
"visual_feats",
"visual_pos",
]
labels = inputs.get("labels", None)
@@ -781,10 +785,17 @@ class ModelTesterMixin:
model_output = model(**filtered_inputs)
if (
isinstance(model, tuple(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()))
and not hasattr(model.config, "problem_type")
or model.config.problem_type is None
):
model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs)
except RuntimeError as e:
except Exception as e:
self.fail(f"Couldn't trace module: {e}")
def flatten_output(output):
......