Unverified commit 84162068, authored by João Gustavo A. Amorim, committed by GitHub

apply unpack_inputs decorator to ViT model (#16102)

parent 62b05b69
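
For reference before the diff: `unpack_inputs` replaces the per-method `input_processing` boilerplate. Instead of every `call` building an `inputs` dict and indexing into it (`inputs["pixel_values"]`, `inputs["training"]`, ...), the decorator normalizes the arguments once and invokes `call` with plain keyword arguments, so the body can use its parameters directly. Below is a minimal sketch of the idea, not the actual `transformers` implementation (which additionally handles dict/tuple model inputs, the legacy `input_ids` -> `pixel_values` remapping visible in the removed code, and config-derived defaults); `unpack_inputs_sketch` and `TinyLayer` are hypothetical names for illustration.

```python
import functools
import inspect


def unpack_inputs_sketch(call_fn):
    """Illustrative stand-in for `unpack_inputs` (simplified): bind whatever
    the caller passed against `call`'s real signature, fill in defaults, and
    re-invoke `call` with plain keyword arguments."""
    signature = inspect.signature(call_fn)
    # Name of the `**kwargs` parameter, if the wrapped `call` declares one.
    var_keyword = next(
        (p.name for p in signature.parameters.values() if p.kind is inspect.Parameter.VAR_KEYWORD),
        None,
    )

    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()  # every parameter now has a concrete value
        unpacked = dict(bound.arguments)
        unpacked.pop("self")
        extra = unpacked.pop(var_keyword, {}) if var_keyword else {}
        return call_fn(self, **unpacked, **extra)

    return wrapper


class TinyLayer:
    @unpack_inputs_sketch
    def call(self, pixel_values=None, training=False, **kwargs):
        # Arguments arrive as ordinary locals; no `inputs["..."]` indexing.
        return pixel_values, training


print(TinyLayer().call([1, 2, 3], training=True))  # ([1, 2, 3], True)
```

With the arguments unpacked up front, each `call` body in the diff below shrinks to the direct-parameter form shown on the `+` lines.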
@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,29 +489,14 @@
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )

         # Prepare head mask if needed
@@ -518,25 +504,25 @@
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers

         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]

         return TFBaseModelOutputWithPooling(
@@ -659,6 +645,7 @@ class TFViTModel(TFViTPreTrainedModel):
         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -773,6 +745,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
+        outputs = self.vit(
+            pixel_values=pixel_values,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
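
From the caller's perspective nothing changes: the models still accept the same keyword arguments, as in the usage examples kept in the docstrings above. A quick end-to-end check along those lines (checkpoint name taken from the standard ViT docstring; downloads weights on first run):

```python
from PIL import Image
import requests
import tensorflow as tf

from transformers import TFViTForImageClassification, ViTFeatureExtractor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)  # `pixel_values` is unpacked by the decorator

logits = outputs.logits
predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
```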