Unverified commit 1558d191, authored by Julien Plu, committed via GitHub
Browse files

Fix TF BART for saved model creation (#9252)



* Fix TF BART for saved model creation

* Apply style

* Update src/transformers/models/bart/modeling_tf_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_tf_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Rework the fix

* Fix condition

* Apply style

* Fix condition

* Fix shape_list

* Apply Patrick's solution

* Apply Patrick's solution

* Rebase

* make tests pass
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarpatrickvonplaten <patrick.v.platen@gmail.com>
parent 37d6fb5d
...@@ -1356,7 +1356,7 @@ def shape_list(tensor: tf.Tensor) -> List[int]: ...@@ -1356,7 +1356,7 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
dynamic = tf.shape(tensor) dynamic = tf.shape(tensor)
if tensor.shape == tf.TensorShape(None): if tensor.shape == tf.TensorShape(None):
return dynamic.as_list() return dynamic
static = tensor.shape.as_list() static = tensor.shape.as_list()
......
...@@ -684,23 +684,21 @@ class TFBartEncoder(tf.keras.layers.Layer): ...@@ -684,23 +684,21 @@ class TFBartEncoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["inputs_embeds"] is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.embed_tokens(inputs["input_ids"]) inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
else: else:
inputs_embeds = inputs["inputs_embeds"] inputs["inputs_embeds"] = inputs["inputs_embeds"]
inputs_embeds = inputs_embeds * self.embed_scale inputs["inputs_embeds"] = inputs["inputs_embeds"] * self.embed_scale
embed_pos = self.embed_positions(input_shape) embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos hidden_states = inputs["inputs_embeds"] + embed_pos
hidden_states = self.layernorm_embedding(hidden_states) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"]) hidden_states = self.dropout(hidden_states, training=inputs["training"])
# check attention mask and invert # check attention mask and invert
if inputs["attention_mask"] is not None: if inputs["attention_mask"] is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(inputs["attention_mask"]) inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
else:
attention_mask = None
encoder_states = () if inputs["output_hidden_states"] else None encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None all_attentions = () if inputs["output_attentions"] else None
...@@ -715,7 +713,7 @@ class TFBartEncoder(tf.keras.layers.Layer): ...@@ -715,7 +713,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask) hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
if inputs["output_attentions"]: if inputs["output_attentions"]:
all_attentions += (attn,) all_attentions += (attn,)
...@@ -876,37 +874,43 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -876,37 +874,43 @@ class TFBartDecoder(tf.keras.layers.Layer):
# embed positions # embed positions
positions = self.embed_positions(input_shape, past_key_values_length) positions = self.embed_positions(input_shape, past_key_values_length)
if inputs_embeds is None: if inputs["inputs_embeds"] is None:
inputs_embeds = self.embed_tokens(inputs["input_ids"]) inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
else:
inputs_embeds = inputs["inputs_embeds"]
hidden_states = inputs_embeds * self.embed_scale hidden_states = inputs["inputs_embeds"] * self.embed_scale
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1: if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1: if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
attention_mask = tf.cast( inputs["attention_mask"] = tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
) )
attention_mask = tf.concat( inputs["attention_mask"] = tf.concat(
[tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask], [
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
inputs["attention_mask"],
],
axis=-1, axis=-1,
) )
else: else:
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32) inputs["attention_mask"] = tf.ones(
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
encoder_hidden_states = inputs["encoder_hidden_states"] if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
if self.do_blenderbot_90_layernorm: if self.do_blenderbot_90_layernorm:
hidden_states = self.layernorm_embedding(hidden_states) + positions hidden_states = self.layernorm_embedding(hidden_states) + positions
...@@ -932,8 +936,8 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -932,8 +936,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
hidden_states, layer_self_attn, present_key_value = decoder_layer( hidden_states, layer_self_attn, present_key_value = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=inputs["encoder_attention_mask"],
past_key_value=past_key_value, past_key_value=past_key_value,
) )
...@@ -954,7 +958,7 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -954,7 +958,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None present_key_values = (inputs["encoder_hidden_states"], present_key_values) if inputs["use_cache"] else None
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns return hidden_states, present_key_values, all_hidden_states, all_self_attns
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment