Unverified Commit 3f77c26d authored by Julien Plu, committed by GitHub

Fix Longformer and LED (#9942)

* Fix Longformer and LED

* Add a test for graph execution with inputs_embeds

* Apply style
parent d55e10be
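
The change made in both models is the same: the sequence-padding logic has to survive being traced into a TensorFlow graph, where `padding_len` is a symbolic tensor rather than a Python int. A plain Python `if padding_len > 0:` is fragile under tracing, so the diff expresses the decision with graph ops (`tf.math.greater`, `tf.cond`) instead. Below is a minimal sketch of that pattern, independent of the model code; the function name and shapes are made up for illustration.

import tensorflow as tf

# Hypothetical toy function, not the model code: pad the sequence axis to a
# multiple of `window` inside a traced graph, where padding_len is symbolic.
@tf.function
def pad_to_multiple(x, window):
    seq_len = tf.shape(x)[1]
    padding_len = (window - seq_len % window) % window
    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
    # tf.cond traces both branches and selects one at run time, so the graph
    # stays valid even though padding_len is only known when it executes.
    return tf.cond(
        tf.math.greater(padding_len, 0),
        lambda: tf.pad(x, paddings),
        lambda: x,
    )

print(pad_to_multiple(tf.ones((2, 5)), 4).shape)  # (2, 8)
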
src/transformers/models/led/modeling_tf_led.py
@@ -1665,7 +1665,6 @@ class TFLEDEncoder(tf.keras.layers.Layer):
     def compute_hidden_states(self, hidden_states, padding_len):
         return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
 
-    @tf.function
     def _pad_to_window_size(
         self,
         input_ids,
@@ -1685,26 +1684,28 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window
 
-        if padding_len > 0:
+        if tf.math.greater(padding_len, 0):
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
                 )
             )
 
-            paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
+        paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
 
-            if input_ids is not None:
-                input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
+        if input_ids is not None:
+            input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
 
-            if inputs_embeds is not None:
-                input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
-                inputs_embeds_padding = self.embed_tokens(input_ids_padding)
-                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+        if inputs_embeds is not None:
+
+            def pad_embeddings():
+                input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
+                inputs_embeds_padding = self.embed_tokens(input_ids_padding)
+                return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+
+            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
 
-            attention_mask = tf.pad(
-                attention_mask, paddings, constant_values=False
-            )  # no attention on the padding tokens
+        attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
 
         return (
             padding_len,
...
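
In the LED hunk above, appending the pad embeddings is wrapped in a `pad_embeddings` closure and handed to `tf.cond`, so the concat only runs when padding is actually needed. The same pattern in isolation, with a plain Keras `Embedding` layer standing in for the model's `embed_tokens` (the layer sizes and pad id below are made up):

import tensorflow as tf

embed = tf.keras.layers.Embedding(input_dim=100, output_dim=8)  # stand-in for embed_tokens
pad_token_id = 1  # illustrative value

@tf.function
def pad_inputs_embeds(inputs_embeds, attention_window):
    batch_size = tf.shape(inputs_embeds)[0]
    seq_len = tf.shape(inputs_embeds)[1]
    padding_len = (attention_window - seq_len % attention_window) % attention_window

    def pad_embeddings():
        # Embed a block of pad tokens and append it along the sequence axis.
        input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
        inputs_embeds_padding = embed(input_ids_padding)
        return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

    # Defer the decision to run time; both branches return the same dtype.
    return tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)

embeds = embed(tf.ones((2, 5), dtype=tf.int32))
print(pad_inputs_embeds(embeds, 4).shape)  # (2, 8, 8)
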
src/transformers/models/longformer/modeling_tf_longformer.py
@@ -1836,7 +1836,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window
 
-        if padding_len > 0:
+        if tf.math.greater(padding_len, 0):
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
@@ -1859,7 +1859,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
                 inputs_embeds_padding = self.embeddings(input_ids_padding)
                 return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
 
-            inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
+            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
 
         attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
         token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0
...
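
The Longformer hunks make the same predicate change inside `tf.cond`; the surrounding code, visible as context above, then pads `attention_mask` and `token_type_ids` so every input ends up a multiple of the attention window. A standalone sketch of that mask padding, with made-up shapes and an illustrative window size:

import tensorflow as tf

@tf.function
def pad_masks(attention_mask, token_type_ids, attention_window):
    seq_len = tf.shape(attention_mask)[1]
    padding_len = (attention_window - seq_len % attention_window) % attention_window
    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
    # Padded positions get attention_mask=0 (no attention) and token_type_id=0,
    # mirroring the constant_values used in the diff above.
    attention_mask = tf.pad(attention_mask, paddings, constant_values=False)
    token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)
    return attention_mask, token_type_ids

mask, types = pad_masks(tf.ones((2, 5), dtype=tf.int32), tf.zeros((2, 5), dtype=tf.int32), 4)
print(mask.shape, types.shape)  # (2, 8) (2, 8)
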
tests/test_modeling_tf_common.py
@@ -884,6 +884,35 @@ class TFModelTesterMixin:
             model(inputs)
 
+    def test_graph_mode_with_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+            else:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
+
+            @tf.function
+            def run_in_graph_mode():
+                return model(inputs)
+
+            outputs = run_in_graph_mode()
+            self.assertIsNotNone(outputs)
+
     def test_numpy_arrays_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
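
The new common test swaps `input_ids` for `inputs_embeds` (and `decoder_inputs_embeds` for encoder-decoder models) and then calls the model inside a `@tf.function`, which is exactly the path the two fixes above unblock. A usage sketch of the same check against a single concrete model follows; the checkpoint name is the public Longformer base model, not something taken from this diff, and any Longformer or LED checkpoint should behave the same way.

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

enc = tokenizer("Graph execution with inputs_embeds", return_tensors="tf")
# Build inputs_embeds the same way the test does: feed input_ids through the
# model's own input embedding layer.
inputs_embeds = model.get_input_embeddings()(enc["input_ids"])

@tf.function
def run_in_graph_mode():
    return model({"inputs_embeds": inputs_embeds, "attention_mask": enc["attention_mask"]})

outputs = run_in_graph_mode()
assert outputs is not None  # mirrors the assertion in the new common test
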