Unverified Commit 84162068 authored by João Gustavo A. Amorim, committed by GitHub

apply unpack_inputs decorator to ViT model (#16102)

parent 62b05b69
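
The change replaces the per-method `input_processing(...)` boilerplate with a single `@unpack_inputs` decorator on each `call`, so the method bodies can use plain keyword arguments (`pixel_values`, `head_mask`, ...) instead of the `inputs` dict. As a rough, hypothetical sketch of what such a decorator does (the real `unpack_inputs` in `modeling_tf_utils` handles many more input formats), consider:

```python
import functools

def unpack_inputs_sketch(call_fn):
    """Hypothetical stand-in for transformers' `unpack_inputs` decorator.

    The real decorator normalizes many input formats (dicts, tuples,
    keyword arguments); this sketch only mirrors the one normalization
    visible in this diff: remapping the generic `input_ids` slot onto
    `pixel_values` for vision models.
    """

    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        # Mirror the deleted `inputs.pop("input_ids")` branch: vision models
        # take `pixel_values`, so a value passed as `input_ids` is remapped
        # before the wrapped `call` body ever sees it.
        if "input_ids" in kwargs and "pixel_values" not in kwargs:
            kwargs["pixel_values"] = kwargs.pop("input_ids")
        return call_fn(self, *args, **kwargs)

    return wrapper
```

Because the decorator owns the normalization, each `call` body shrinks to straight-line code over its own parameters, as the diff below shows.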
@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,29 +489,14 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )

         # Prepare head mask if needed
@@ -518,25 +504,25 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers

         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]

         return TFBaseModelOutputWithPooling(
@@ -659,6 +645,7 @@ class TFViTModel(TFViTPreTrainedModel):

         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@ class TFViTModel(TFViTPreTrainedModel):
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+        outputs = self.vit(
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
         )

         return outputs
@@ -773,6 +745,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+        outputs = self.vit(
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
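
For callers, the refactor is behavior-preserving: the decorator normalizes inputs before `call` runs, so the docstring examples above work unchanged. A minimal usage sketch follows; the checkpoint name and image URL are assumptions matching the usual ViT examples, not something this diff shows:

```python
# Usage is identical before and after the change, since @unpack_inputs
# normalizes the keyword arguments before the `call` body executes.
import requests
from PIL import Image
from transformers import TFViTModel, ViTFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state  # same result as before the refactor
```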