Unverified Commit 7203ea67 authored by Matt, committed by GitHub

Reduce memory usage in TF building (#24046)

* Make the default dummies (2, 2) instead of (3, 3)

* Fix for Funnel

* Actually fix Funnel
parent 072188d6
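In short, the generic dummy-input builder in `TFPreTrainedModel` now fills every unknown dimension with 2 instead of 3. A minimal standalone sketch of that logic follows; the `input_ids` signature below is a made-up example for illustration, not taken from the diff:

```python
import tensorflow as tf

# Hypothetical input signature: batch and sequence dimensions are unknown (None)
sig = {"input_ids": tf.TensorSpec(shape=(None, None), dtype=tf.int32)}

dummies = {}
for key, spec in sig.items():
    # Replace every unknown dimension with 2 (previously 3), keeping the dummy
    # tensors used to build/trace the model as small as possible
    dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)

print(dummies["input_ids"].shape)  # (2, 2) rather than the old (3, 3)
```

Smaller dummies mean smaller activations during the one-off build pass, which is where the memory saving comes from.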
@@ -1116,8 +1116,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         dummies = {}
         sig = self._prune_signature(self.input_signature)
         for key, spec in sig.items():
-            # 3 is the most correct arbitrary size. I will not be taking questions
-            dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
+            # 2 is the most correct arbitrary size. I will not be taking questions
+            dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
             if key == "token_type_ids":
                 # Some models have token_type_ids but with a vocab_size of 1
                 dummies[key] = tf.zeros_like(dummies[key])
@@ -1125,7 +1125,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             if "encoder_hidden_states" not in dummies:
                 if self.main_input_name == "input_ids":
                     dummies["encoder_hidden_states"] = tf.ones(
-                        shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                        shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                     )
                 else:
                     raise NotImplementedError(
...
@@ -242,6 +242,7 @@ class TFFunnelAttentionStructure:
             # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
             rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
             rel_pos = rel_pos + zero_offset
+            tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0])
             position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)

             position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
@@ -974,6 +975,11 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
     config_class = FunnelConfig
     base_model_prefix = "funnel"

+    @property
+    def dummy_inputs(self):
+        # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
+        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}
+

 @dataclass
 class TFFunnelForPreTrainingOutput(ModelOutput):
@@ -1424,6 +1430,10 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
         self.funnel = TFFunnelBaseLayer(config, name="funnel")
         self.classifier = TFFunnelClassificationHead(config, 1, name="classifier")

+    @property
+    def dummy_inputs(self):
+        return {"input_ids": tf.ones((3, 3, 4), dtype=tf.int32)}
+
     @unpack_inputs
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
...
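The Funnel changes follow the override pattern visible above: the base class provides small generic dummies, and a subclass that cannot handle them supplies its own `dummy_inputs` property. A hedged sketch of that pattern with stand-in classes (not the actual transformers classes):

```python
import tensorflow as tf

class BaseModel:
    @property
    def dummy_inputs(self):
        # Generic default after this change: small (2, 2) dummies
        return {"input_ids": tf.ones((2, 2), dtype=tf.int32)}

class PoolingModel(BaseModel):
    @property
    def dummy_inputs(self):
        # Funnel-style pooling shrinks the sequence dimension, so very small
        # inputs misbehave; override with slightly bigger (3, 3) dummies
        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}

print(PoolingModel().dummy_inputs["input_ids"].shape)  # (3, 3)
```

The multiple-choice head expects a (batch_size, num_choices, sequence_length) input, per its docstring format string, hence its three-dimensional (3, 3, 4) dummy.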