Unverified Commit 1fc832b4 authored by Matt, committed by GitHub

Make the TF dummies even smaller (#24071)

* Let's see if we can use the smallest possible dummies

* Make GPT-2's dummies a little longer

* Just use (1,2) as the default shape

* Update other dummies in sync

* Correct imports for Keras 2.13

* Shrink the Wav2Vec2 dummies
parent 092c14c3
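
In short, the default dummies built from a model's input signature now use batch size 1 and size 2 for any other unknown dimension, i.e. a (1, 2) shape for typical text inputs. A minimal sketch of how to inspect the result (assumes TensorFlow and a transformers build containing this change; "bert-base-uncased" is just an example checkpoint):

    # Inspect the dummies a TF model builds for itself after this change.
    from transformers import TFAutoModel

    model = TFAutoModel.from_pretrained("bert-base-uncased")
    for name, tensor in model.dummy_inputs.items():
        # Unknown batch dims become 1 and other unknown dims become 2,
        # so input_ids should come out with shape (1, 2).
        print(name, tensor.shape, tensor.dtype)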
@@ -74,7 +74,7 @@ from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 if parse(tf.__version__).minor >= 13:
     from keras import backend as K
     from keras.__internal__ import KerasTensor
-    from keras.engine.base_layer_utils import call_context
+    from keras.src.engine.base_layer_utils import call_context
 elif parse(tf.__version__).minor >= 11:
     from keras import backend as K
     from keras.engine.base_layer_utils import call_context
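
For context on this hunk: Keras 2.13 moved its internals under the private keras.src namespace, so the old keras.engine path stopped resolving on TF 2.13. A standalone sketch of the same version gate (assumes the `packaging` library, as the surrounding file does):

    # Standalone sketch of the version gate used above.
    import tensorflow as tf
    from packaging.version import parse

    if parse(tf.__version__).minor >= 13:
        # Keras 2.13 relocated internals under the private `keras.src` namespace
        from keras.src.engine.base_layer_utils import call_context
    else:
        from keras.engine.base_layer_utils import call_context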
@@ -1125,7 +1125,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         sig = self._prune_signature(self.input_signature)
         for key, spec in sig.items():
             # 2 is the most correct arbitrary size. I will not be taking questions
-            dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
+            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
+            if spec.shape[0] is None:
+                # But let's make the batch size 1 to save memory anyway
+                dummy_shape[0] = 1
+            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
             if key == "token_type_ids":
                 # Some models have token_type_ids but with a vocab_size of 1
                 dummies[key] = tf.zeros_like(dummies[key])
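
The effect of the new shape logic, sketched on a typical input_ids spec (plain TensorFlow, runnable on its own):

    import tensorflow as tf

    # A typical input_ids spec: batch and sequence dims both unknown.
    spec = tf.TensorSpec(shape=(None, None), dtype=tf.int32)

    # Same logic as the hunk above: unknown dims default to 2...
    dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
    if spec.shape[0] is None:
        dummy_shape[0] = 1  # ...but an unknown batch dim shrinks to 1
    print(dummy_shape)  # [1, 2]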
@@ -1133,7 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if "encoder_hidden_states" not in dummies:
             if self.main_input_name == "input_ids":
                 dummies["encoder_hidden_states"] = tf.ones(
-                    shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                    shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                 )
             else:
                 raise NotImplementedError(
......
@@ -978,7 +978,7 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
         # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
-        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}
+        return {"input_ids": tf.ones((1, 3), dtype=tf.int32)}
 
 
 @dataclass
......
@@ -1147,21 +1147,6 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
     base_model_prefix = "sam"
     main_input_name = "pixel_values"
-
-    @property
-    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
-        # We override the default dummy inputs here because SAM has some really explosive memory usage in the
-        # attention layers, so we want to pass the smallest possible batches
-        VISION_DUMMY_INPUTS = tf.random.uniform(
-            shape=(
-                1,
-                self.config.vision_config.num_channels,
-                self.config.vision_config.image_size,
-                self.config.vision_config.image_size,
-            ),
-            dtype=tf.float32,
-        )
-        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
 
 
 SAM_START_DOCSTRING = r"""
     This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
......
@@ -1194,8 +1194,8 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
         return {
-            "input_values": tf.random.uniform(shape=(1, 16000), dtype=tf.float32),
-            "attention_mask": tf.ones(shape=(1, 16000), dtype=tf.float32),
+            "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32),
+            "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32),
         }
 
     def __init__(self, config, *inputs, **kwargs):
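
For scale: Wav2Vec2 consumes raw 16 kHz audio, so the old dummies were a full second of waveform while the new ones are about 31 ms. A sketch of why 500 samples is still enough, assuming the default conv feature extractor (kernels 10,3,3,3,3,2,2 with strides 5,2,2,2,2,2,2):

    # Output length of Wav2Vec2's conv feature extractor (default config assumed).
    def conv_out_len(length, kernels=(10, 3, 3, 3, 3, 2, 2), strides=(5, 2, 2, 2, 2, 2, 2)):
        for k, s in zip(kernels, strides):
            length = (length - k) // s + 1
        return length

    print(conv_out_len(16_000))  # 49 frames from the old one-second dummies
    print(conv_out_len(500))     # 1 frame -- the new dummies still produce output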
......
@@ -481,9 +481,9 @@ class TFWhisperPreTrainedModel(TFPreTrainedModel):
         """
         return {
             self.main_input_name: tf.random.uniform(
-                [2, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
+                [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
             ),
-            "decoder_input_ids": tf.constant([[2, 3]], dtype=tf.int32),
+            "decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
         }
 
     @property
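
A note on the odd-looking frame count this hunk keeps: Whisper's encoder downsamples the mel spectrogram with a stride-2 conv (kernel 3, padding 1), so max_source_positions * 2 - 1 input frames map to exactly max_source_positions encoder positions. A quick check of that arithmetic:

    # Length after a conv with kernel 3, stride 2, padding 1 (Whisper's second conv).
    def conv_out(n, kernel=3, stride=2, padding=1):
        return (n + 2 * padding - kernel) // stride + 1

    max_source_positions = 1500  # Whisper's usual value; assumed here for illustration
    print(conv_out(max_source_positions * 2 - 1))  # 1500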
......