Unverified Commit f06c2c2b authored by Johannes Kolbe's avatar Johannes Kolbe Committed by GitHub
Browse files

TF unpack_input decorator for convnext (#16181)



* unpack_input decorator for tf_convnext

* set unpack_input as top decorator
Co-authored-by: default avatarJohannes Kolbe <johannes.kolbe@tech.better.team>
parent d35e0c62
...@@ -28,8 +28,8 @@ from ...modeling_tf_utils import ( ...@@ -28,8 +28,8 @@ from ...modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
TFSequenceClassificationLoss, TFSequenceClassificationLoss,
get_initializer, get_initializer,
input_processing,
keras_serializable, keras_serializable,
unpack_inputs,
) )
from ...utils import logging from ...utils import logging
from .configuration_convnext import ConvNextConfig from .configuration_convnext import ConvNextConfig
...@@ -287,6 +287,7 @@ class TFConvNextMainLayer(tf.keras.layers.Layer): ...@@ -287,6 +287,7 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
# NCHW output format # NCHW output format
self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
@unpack_inputs
def call( def call(
self, self,
pixel_values: Optional[TFModelInputType] = None, pixel_values: Optional[TFModelInputType] = None,
...@@ -300,29 +301,16 @@ class TFConvNextMainLayer(tf.keras.layers.Layer): ...@@ -300,29 +301,16 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing( if pixel_values is None:
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
embedding_output = self.embeddings(inputs["pixel_values"], training=inputs["training"]) embedding_output = self.embeddings(pixel_values, training=training)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=inputs["training"], training=training,
) )
last_hidden_state = encoder_outputs[0] last_hidden_state = encoder_outputs[0]
...@@ -443,6 +431,7 @@ class TFConvNextModel(TFConvNextPreTrainedModel): ...@@ -443,6 +431,7 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext") self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -478,27 +467,14 @@ class TFConvNextModel(TFConvNextPreTrainedModel): ...@@ -478,27 +467,14 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing( if pixel_values is None:
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
outputs = self.convnext( outputs = self.convnext(
pixel_values=inputs["pixel_values"], pixel_values=pixel_values,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=inputs["training"], training=training,
) )
if not return_dict: if not return_dict:
...@@ -533,6 +509,7 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas ...@@ -533,6 +509,7 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
name="classifier", name="classifier",
) )
@unpack_inputs
@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -578,36 +555,22 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas ...@@ -578,36 +555,22 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs = input_processing( if pixel_values is None:
func=self.call,
config=self.config,
input_ids=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
outputs = self.convnext( outputs = self.convnext(
inputs["pixel_values"], pixel_values,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
training=inputs["training"], training=training,
) )
pooled_output = outputs.pooler_output if return_dict else outputs[1] pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not inputs["return_dict"]: if not return_dict:
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment