Unverified Commit fe4197ab authored by Patrick von Platen, committed by GitHub

[Generate] Remove attention_mask and integrate model_main_input_name (#14856)

* up

* save

* correct

* up

* correct more

* up

* up

* up

* up

* up

* correct

* fix tf

* fix

* remove tokenizer
parent 86b40073
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -349,9 +350,6 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]

-ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]

 class GenerationMixin:
     """
     A class containing all of the functions supporting generation, to be used as a mixin in
@@ -363,58 +361,69 @@ class GenerationMixin:
         inputs: Optional[torch.Tensor] = None,
         bos_token_id: Optional[int] = None,
         model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Optional[str]]:
+    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
         """
         This function extracts the model-specific `inputs` for generation.
         """
-        # filter model input names that are `None`
-        model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
-        # extract keyword arguments that are model input specific
-        model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
-
-        # There are 5 possible scenarios
-        if inputs is not None and len(model_input_kwarg_names) == 0:
-            # 1. `inputs` are passed and no model-specific keyword inputs
-            # -> return input
-            model_input_name = None
-            return inputs, model_input_name, model_kwargs
-        elif inputs is not None and len(model_input_kwarg_names) > 0:
-            # 2. `inputs` are passed as well as model-specific keyword inputs
-            # -> not allowed, raise Error
-            raise ValueError(
-                f"`inputs`: {inputs}` were passed alongside "
-                f"{model_input_kwarg_names} which is not allowed."
-                f"Make sure to not pass any of {model_input_kwarg_names} "
-                "when `inputs` is defined."
-            )
-        elif inputs is None and len(model_input_kwarg_names) == 0:
-            # 3. no `inputs` and no model-specific keyword inputs are passed
-            # -> try to create `input_ids` from BOS
-            input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
-            return input_tensor, "input_ids", model_kwargs
-        elif inputs is None and len(model_input_kwarg_names) == 1:
-            # 4. no `inputs` are passed and exactly one model-specific keyword input
-            # -> return that model-specific keyword input tensor
-            model_input_name = model_input_kwarg_names.pop()
-            input_tensor = model_kwargs.pop(model_input_name)
-            # make sure model is encoder decoder if not `input_ids`
-            if not self.config.is_encoder_decoder and model_input_name != "input_ids":
-                raise ValueError(
-                    f"If {model_input_name} is passed as model-specific keyword "
-                    "input then model has to be an encoder-decoder and not a "
-                    f"{self.__class__.__name__}."
-                )
-            return input_tensor, model_input_name, model_kwargs
-        else:
-            # 5. no `inputs` are passed and multiple model-specific keyword inputs
-            # -> not allowed, raise Error
-            raise ValueError(
-                f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
-                f"but passed {model_input_kwarg_names}."
-                f"Make sure to only pass one of {model_input_kwarg_names}."
-            )
+        # 1. retrieve all kwargs that are non-None or non-model input related.
+        # some encoder-decoder models have different names for model and encoder
+        if (
+            self.config.is_encoder_decoder
+            and hasattr(self, "encoder")
+            and self.encoder.main_input_name != self.main_input_name
+        ):
+            input_name = self.encoder.main_input_name
+        else:
+            input_name = self.main_input_name
+
+        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+        # 2. check whether model_input_name is passed as kwarg
+        # if yes and `inputs` is None use kwarg inputs
+        inputs_kwarg = model_kwargs.pop(input_name, None)
+        if inputs_kwarg is not None and inputs is not None:
+            raise ValueError(
+                f"`inputs`: {inputs}` were passed alongside "
+                f"{input_name} which is not allowed."
+                f"Make sure to either pass {inputs} or {input_name}=..."
+            )
+        elif inputs_kwarg is not None:
+            inputs = inputs_kwarg
+
+        # 3. models with `input_ids` can also make use of `inputs_embeds`
+        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+
+        # 4. Only encoder-decoder models can have non `input_ids` input format
+        if not self.config.is_encoder_decoder and input_name != "input_ids":
+            raise ValueError(
+                f"If {input_name} is passed as model-specific keyword "
+                "input then model has to be an encoder-decoder and not a "
+                f"{self.__class__.__name__}."
+            )
+
+        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
+        if inputs is None:
+            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+
+        return inputs, input_name, model_kwargs
+
+    def _can_retrieve_inputs_from_name(
+        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        If `inputs` is None and `name` is in both forward function and keyword
+        arguments, then inputs can be retrieved from name
+        """
+        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
+
+        if can_retrieve_inputs and inputs is not None:
+            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
+
+        return can_retrieve_inputs

     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the
@@ -461,29 +470,22 @@ class GenerationMixin:
     def _prepare_encoder_decoder_kwargs_for_generation(
         self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
     ) -> Dict[str, Any]:
-        if "encoder_outputs" not in model_kwargs:
-            # 1. get encoder
-            encoder = self.get_encoder()
-            # 2. prepare encoder args and encoder kwargs from model kwargs
-            encoder_args = (inputs_tensor,)
-            irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
-            encoder_kwargs = {
-                argument: value
-                for argument, value in model_kwargs.items()
-                if not any(argument.startswith(p) for p in irrelevant_prefix)
-            }
-            # 3. make sure that encoder returns `ModelOutput`
-            encoder_kwargs["return_dict"] = True
-            # 4. if model_input_name is not defined then pass input_tensor as
-            # first input argument and remove from args
-            if model_input_name is not None:
-                # make sure inputs_tensor is None in case model
-                # accepts multiple model input arguments
-                encoder_kwargs[model_input_name] = inputs_tensor
-                encoder_args = ()
-            model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
+        # 1. get encoder
+        encoder = self.get_encoder()
+
+        # 2. prepare encoder args and encoder kwargs from model kwargs
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+        encoder_kwargs["return_dict"] = True
+        encoder_kwargs[model_input_name] = inputs_tensor
+        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
+
         return model_kwargs
@@ -1013,12 +1015,13 @@ class GenerationMixin:
         model_kwargs["output_hidden_states"] = output_hidden_states
         model_kwargs["use_cache"] = use_cache

-        if model_kwargs.get("attention_mask", None) is None:
+        has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor, pad_token_id, eos_token_id
             )

-        if self.config.is_encoder_decoder:
+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
             # if model is encoder decoder encoder_outputs are created
             # and added to `model_kwargs`
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
...
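The core of the change above is that `generate()` no longer guesses the model input from a fixed whitelist but asks the model which input it expects. A minimal standalone sketch of that resolution step, in plain Python (`DummyConfig`, `DummyEncoder` and `DummyModel` are illustrative stand-ins, not library code):

class DummyConfig:
    is_encoder_decoder = True


class DummyEncoder:
    # e.g. a speech encoder whose forward takes `input_values`
    main_input_name = "input_values"


class DummyModel:
    # e.g. a SpeechEncoderDecoderModel-style wrapper with a generic main input
    config = DummyConfig()
    main_input_name = "inputs"
    encoder = DummyEncoder()


def resolve_input_name(model) -> str:
    # Mirrors step 1 of `_prepare_model_inputs`: prefer the encoder's input
    # name when the wrapper and its encoder disagree.
    if (
        model.config.is_encoder_decoder
        and hasattr(model, "encoder")
        and model.encoder.main_input_name != model.main_input_name
    ):
        return model.encoder.main_input_name
    return model.main_input_name


print(resolve_input_name(DummyModel()))  # -> "input_values"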
@@ -57,8 +57,6 @@ class KerasMetricCallback(Callback):
             Validation data to be used to generate predictions for the `metric_fn`.
         metric_fn_kwargs (`dict`, *optional*):
             Additional keyword arguments to be passed to the metric_fn.
-        tokenizer ([`PretrainedTokenizerBase`], *optional*):
-            Tokenizer used to validate column names to be passed to the generate() function.
         output_cols (`List[str], *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
         label_cols ('`List[str]`, *optional*'):
@@ -75,7 +73,6 @@ class KerasMetricCallback(Callback):
         self,
         metric_fn: Callable,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         metric_fn_kwargs: Optional[dict] = None,
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
@@ -97,10 +94,11 @@ class KerasMetricCallback(Callback):
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
         self.metric_fn_kwargs = metric_fn_kwargs or dict()
-        if tokenizer is not None:
-            self.model_input_names = tokenizer.model_input_names
+
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            self.main_input_name = self.model.encoder.main_input_name
         else:
-            self.model_input_names = ["input_ids"]
+            self.main_input_name = self.model.main_input_name

         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn
@@ -161,9 +159,13 @@ class KerasMetricCallback(Callback):
             labels = None
             if self.predict_with_generate:
                 if isinstance(batch, dict):
-                    # generate() gets stressed out by any unexpected keys
-                    batch = {key: array for key, array in batch.items() if key in self.model_input_names}
-                predictions = self.model.generate(batch)
+                    generation_inputs = batch[self.main_input_name]
+                    attention_mask = batch.get("attention_mask", None)
+                else:
+                    generation_inputs = batch
+                    attention_mask = None
+                predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
             else:
                 predictions = self.model.predict(batch)
             predictions = dict(predictions)
...
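The callback now selects generation inputs by the resolved `main_input_name` instead of filtering dict batches through `tokenizer.model_input_names`. A hedged sketch of the per-batch selection, with `main_input_name` and `batch` as stand-ins for the attributes and data used above:

import numpy as np

# Stand-ins: `main_input_name` would be resolved from the model/encoder as in
# the constructor above; `batch` is one element of the evaluation dataset.
main_input_name = "input_ids"
batch = {
    "input_ids": np.ones((2, 8), dtype="int32"),
    "attention_mask": np.ones((2, 8), dtype="int32"),
    "labels": np.zeros((2, 8), dtype="int32"),  # not needed for generation
}

if isinstance(batch, dict):
    # pull only the main input; keep the attention mask if the batch has one
    generation_inputs = batch[main_input_name]
    attention_mask = batch.get("attention_mask", None)
else:
    generation_inputs = batch
    attention_mask = None

# The callback would then call:
# predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)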
@@ -478,7 +478,6 @@ class DeiTModel(DeiTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
-        attention_mask=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
...
@@ -69,19 +69,11 @@ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
 SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
     Args:
-        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+        inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
+            Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* or *.wav* audio file
             into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
-            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
-            be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
-            [`Wav2Vec2Processor.__call__`] for details.
-        input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
-            Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
-            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
-            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
-            into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
-            the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
-            [`~Speech2TextTokenizer.__call__`]
+            soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should
+            be used for padding and conversion into a tensor of type *torch.FloatTensor*.
         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -137,6 +129,19 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
+            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
+            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
+            be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
+            [`Wav2Vec2Processor.__call__`] for details.
+        input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
+            Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
+            by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
+            into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
+            the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
+            [`~Speech2TextTokenizer.__call__`]
         return_dict (`bool`, *optional*):
             If set to `True`, the model will return a [`~file_utils.Seq2SeqLMOutput`] instead of a
             plain tuple.
@@ -176,7 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     """
     config_class = SpeechEncoderDecoderConfig
     base_model_prefix = "speech_encoder_decoder"
-    main_input_name = "input_values"
+    main_input_name = "inputs"

     def __init__(
         self,
@@ -417,8 +422,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_values=None,
-        input_features=None,
+        inputs=None,
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
@@ -429,6 +433,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
+        input_values=None,
+        input_features=None,
         return_dict=None,
         **kwargs,
     ):
@@ -463,7 +469,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }

-        if encoder_outputs is None:
+        if encoder_outputs is None and inputs is None:
             if input_values is not None and input_features is not None:
                 raise ValueError("You cannot specify both input_values and input_features at the same time")
             elif input_values is not None:
...
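With `main_input_name = "inputs"`, the generic `inputs` argument becomes the model's declared main input, while `input_values` and `input_features` survive as trailing keywords for backward compatibility. A hedged usage sketch of what the diff implies (the checkpoint name is hypothetical, for illustration only):

import torch
from transformers import SpeechEncoderDecoderModel

# Hypothetical checkpoint name; any Wav2Vec2-style encoder + text decoder
# combination should behave the same way.
model = SpeechEncoderDecoderModel.from_pretrained("my-org/wav2vec2-to-bert")

waveform = torch.randn(1, 16000)  # raw waveform for a Wav2Vec2-style encoder
decoder_input_ids = torch.tensor([[0]])

# The waveform can now be passed as the generic main input...
out = model(inputs=waveform, decoder_input_ids=decoder_input_ids)
# ...while the legacy keyword is still accepted and, per the forward body
# above, copied into `inputs` when `inputs` is None:
out = model(input_values=waveform, decoder_input_ids=decoder_input_ids)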
@@ -507,7 +507,6 @@ class ViTModel(ViTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
-        attention_mask=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
...
@@ -161,11 +161,17 @@ class Seq2SeqTrainer(Trainer):
             "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
         }

-        model_input_names = self.tokenizer.model_input_names if self.tokenizer is not None else ["input_ids"]
-        generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
+        # prepare generation inputs
+        # some encoder-decoder models can have varying encoders and thus
+        # varying model input names
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            generation_inputs = inputs[self.model.encoder.main_input_name]
+        else:
+            generation_inputs = inputs[self.model.main_input_name]

         generated_tokens = self.model.generate(
-            **generation_inputs,
+            generation_inputs,
+            attention_mask=inputs.get("attention_mask", None),
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
...
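The trainer-side counterpart of the same idea: generation inputs are looked up by the model's (or its encoder's) declared `main_input_name` rather than by `tokenizer.model_input_names`. A sketch of that selection factored into a free function (a hypothetical helper, not part of the diff):

from typing import Dict

import torch


def select_generation_inputs(model, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
    # `inputs` is the dict of tensors that `prediction_step` receives; return
    # only the tensor the model declares as its (encoder's) main input.
    if hasattr(model, "encoder") and model.encoder.main_input_name != model.main_input_name:
        return inputs[model.encoder.main_input_name]
    return inputs[model.main_input_name]


# Inside prediction_step this would be used as:
# generated_tokens = self.model.generate(
#     select_generation_inputs(self.model, inputs),
#     attention_mask=inputs.get("attention_mask", None),
#     **gen_kwargs,
# )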
@@ -1856,7 +1856,7 @@ class GenerationIntegrationTests(unittest.TestCase):
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
         with self.assertRaises(ValueError):
-            model.generate(input_ids=input_ids, input_values=input_ids)
+            model.generate(input_ids=input_ids, inputs_embeds=input_ids)

     def test_generate_input_values_as_encoder_kwarg(self):
         input_values = floats_tensor((2, 250))
...
@@ -64,14 +64,7 @@ class EncoderDecoderMixin:
         pass

     def check_encoder_decoder_model_from_pretrained_configs(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
         self.assertTrue(encoder_decoder_config.decoder.is_decoder)
@@ -84,7 +77,6 @@ class EncoderDecoderMixin:
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )
@@ -94,14 +86,7 @@ class EncoderDecoderMixin:
         )

     def check_encoder_decoder_model(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -111,7 +96,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_hidden_states=True,
@@ -122,7 +106,6 @@ class EncoderDecoderMixin:
         encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
         outputs_encoder_decoder = enc_dec_model(
             encoder_outputs=encoder_outputs,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
         )
@@ -134,7 +117,6 @@ class EncoderDecoderMixin:
     def check_encoder_decoder_model_from_pretrained(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,
@@ -148,7 +130,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_hidden_states=True,
@@ -160,14 +141,7 @@ class EncoderDecoderMixin:
         )

     def check_save_and_load(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -176,7 +150,6 @@ class EncoderDecoderMixin:
         with torch.no_grad():
             outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )
@@ -190,7 +163,6 @@ class EncoderDecoderMixin:
             after_outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )
@@ -200,14 +172,7 @@ class EncoderDecoderMixin:
         self.assertLessEqual(max_diff, 1e-5)

     def check_save_and_load_encoder_decoder_model(
-        self,
-        config,
-        attention_mask,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        pixel_values=None,
-        **kwargs
+        self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
     ):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
@@ -216,7 +181,6 @@ class EncoderDecoderMixin:
         with torch.no_grad():
             outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )
@@ -233,7 +197,6 @@ class EncoderDecoderMixin:
             after_outputs = enc_dec_model(
                 pixel_values=pixel_values,
-                attention_mask=attention_mask,
                 decoder_input_ids=decoder_input_ids,
                 decoder_attention_mask=decoder_attention_mask,
             )
@@ -245,7 +208,6 @@ class EncoderDecoderMixin:
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,
@@ -261,7 +223,6 @@ class EncoderDecoderMixin:
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
@@ -382,13 +343,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
             ]
         )
         # for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
-        seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
-        attention_mask = random_attention_mask([batch_size, seq_len])
         decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
         decoder_attention_mask = random_attention_mask([batch_size, 4])
         inputs = {
             "pixel_values": pixel_values,
-            "attention_mask": attention_mask,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
         }
@@ -398,7 +356,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
-        attention_mask,
         decoder_config,
         decoder_input_ids,
         decoder_attention_mask,
@@ -414,7 +371,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
             pixel_values=pixel_values,
-            attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
@@ -463,7 +419,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
         decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
         (
             decoder_config,
             decoder_input_ids,
@@ -481,7 +436,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_token_type_ids": decoder_token_type_ids,
@@ -509,13 +463,10 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
             ]
         )
         # for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
-        seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
-        attention_mask = random_attention_mask([batch_size, seq_len])
         decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
         decoder_attention_mask = random_attention_mask([batch_size, 4])
         inputs = {
             "pixel_values": pixel_values,
-            "attention_mask": attention_mask,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
         }
@@ -534,7 +485,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
         (
             decoder_config,
@@ -553,7 +503,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_token_type_ids": decoder_token_type_ids,
@@ -580,7 +529,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
         encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
         decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
         config, pixel_values, _ = encoder_config_and_inputs
-        input_mask = None  # TODO add once attention_mask is supported for vision models
         (decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs

         # make sure that cross attention layers are added
@@ -590,7 +538,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
         return {
             "config": config,
             "pixel_values": pixel_values,
-            "attention_mask": input_mask,
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
...