Unverified Commit f06c2c2b authored by Johannes Kolbe's avatar Johannes Kolbe Committed by GitHub
Browse files

TF unpack_input decorator for convnext (#16181)



* unpack_input decorator for tf_convnext

* set unpack_input as top decorator
Co-authored-by: default avatarJohannes Kolbe <johannes.kolbe@tech.better.team>
parent d35e0c62
......@@ -28,8 +28,8 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...utils import logging
from .configuration_convnext import ConvNextConfig
......@@ -287,6 +287,7 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
# NCHW output format
self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
@unpack_inputs
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
......@@ -300,29 +301,16 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
embedding_output = self.embeddings(inputs["pixel_values"], training=inputs["training"])
embedding_output = self.embeddings(pixel_values, training=training)
encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
training=training,
)
last_hidden_state = encoder_outputs[0]
......@@ -443,6 +431,7 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -478,27 +467,14 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
outputs = self.convnext(
pixel_values=inputs["pixel_values"],
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
training=training,
)
if not return_dict:
......@@ -533,6 +509,7 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
name="classifier",
)
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -578,36 +555,22 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
outputs = self.convnext(
inputs["pixel_values"],
pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=inputs["training"],
training=training,
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment