Unverified Commit 1558d191 authored by Julien Plu, committed by GitHub

Fix TF BART for saved model creation (#9252)



* Fix TF BART for saved model creation

* Apply style

* Update src/transformers/models/bart/modeling_tf_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_tf_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Rework the fix

* Fix condition

* Apply style

* Fix condition

* Fix shape_list

* Apply Patrick's solution

* Apply Patrick's solution

* Rebase

* make tests pass
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent 37d6fb5d
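
The fix targets exporting TF BART as a TensorFlow SavedModel. A minimal sketch of the workflow this unblocks (the checkpoint name and output path below are illustrative, not taken from the commit):

import tensorflow as tf
from transformers import BartTokenizer, TFBartModel

# Illustrative checkpoint; any BART checkpoint exercises the same code paths.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartModel.from_pretrained("facebook/bart-base")

# Build the model once with real inputs so all variables are created.
encoded = tokenizer("Hello world", return_tensors="tf")
model(encoded)

# Export as a SavedModel; tracing the BART graph here is what used to fail.
tf.saved_model.save(model, "./tf_bart_saved_model")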
@@ -1356,7 +1356,7 @@ def shape_list(tensor: tf.Tensor) -> List[int]:
     dynamic = tf.shape(tensor)

     if tensor.shape == tf.TensorShape(None):
-        return dynamic.as_list()
+        return dynamic

     static = tensor.shape.as_list()
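For context, shape_list mixes static and dynamic dimensions; the one-line change above matters when a tensor's rank is unknown at trace time, where the old code called .as_list() on a tf.Tensor and raised. A minimal sketch of that case (assuming the helper still lives in modeling_tf_utils, as it did at the time of this commit):

import tensorflow as tf
from transformers.modeling_tf_utils import shape_list

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def shape_of(t):
    # With an unknown-rank spec, t.shape == tf.TensorShape(None) inside the trace,
    # so shape_list now returns the dynamic shape tensor instead of calling .as_list() on it.
    return shape_list(t)

print(shape_of(tf.zeros((2, 3))))  # tf.Tensor([2 3], shape=(2,), dtype=int32)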
@@ -684,23 +684,21 @@ class TFBartEncoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs["inputs_embeds"] is None:
-            inputs_embeds = self.embed_tokens(inputs["input_ids"])
+            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
         else:
-            inputs_embeds = inputs["inputs_embeds"]
+            inputs["inputs_embeds"] = inputs["inputs_embeds"]

-        inputs_embeds = inputs_embeds * self.embed_scale
+        inputs["inputs_embeds"] = inputs["inputs_embeds"] * self.embed_scale
         embed_pos = self.embed_positions(input_shape)
-        hidden_states = inputs_embeds + embed_pos
+        hidden_states = inputs["inputs_embeds"] + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
         hidden_states = self.dropout(hidden_states, training=inputs["training"])

         # check attention mask and invert
         if inputs["attention_mask"] is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(inputs["attention_mask"])
-        else:
-            attention_mask = None
+            inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])

         encoder_states = () if inputs["output_hidden_states"] else None
         all_attentions = () if inputs["output_attentions"] else None
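
The encoder now writes the expanded mask back into the inputs dictionary rather than a local variable, and removes the else branch that set it to None. For readers unfamiliar with the helper, here is a rough paraphrase of what an _expand_mask-style function computes (my own sketch, not the transformers implementation):

import tensorflow as tf

def expand_mask_sketch(mask, tgt_len=None):
    # mask: [bsz, src_len] with 1 for real tokens and 0 for padding.
    src_len = tf.shape(mask)[1]
    tgt_len = src_len if tgt_len is None else tgt_len
    mask = tf.cast(mask, tf.float32)
    # Broadcast to [bsz, 1, tgt_len, src_len] and turn padding positions into a large negative bias.
    expanded = tf.tile(mask[:, None, None, :], tf.stack([1, 1, tgt_len, 1]))
    return (1.0 - expanded) * -1e9

# Example: a batch of one sequence of length 4 with one padding token.
print(expand_mask_sketch(tf.constant([[1, 1, 1, 0]]))[0, 0])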
@@ -715,7 +713,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
             if inputs["training"] and (dropout_probability < self.layerdrop):  # skip the layer
                 continue

-            hidden_states, attn = encoder_layer(hidden_states, attention_mask)
+            hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])

             if inputs["output_attentions"]:
                 all_attentions += (attn,)
@@ -876,37 +874,43 @@ class TFBartDecoder(tf.keras.layers.Layer):
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)

-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(inputs["input_ids"])
-        else:
-            inputs_embeds = inputs["inputs_embeds"]
+        if inputs["inputs_embeds"] is None:
+            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])

-        hidden_states = inputs_embeds * self.embed_scale
+        hidden_states = inputs["inputs_embeds"] * self.embed_scale

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        else:
+            combined_attention_mask = _expand_mask(
+                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+            )

         if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
-            attention_mask = tf.cast(
+            inputs["attention_mask"] = tf.cast(
                 tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
             )
-            attention_mask = tf.concat(
-                [tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask],
+            inputs["attention_mask"] = tf.concat(
+                [
+                    tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
+                    inputs["attention_mask"],
+                ],
                 axis=-1,
             )
         else:
-            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
+            inputs["attention_mask"] = tf.ones(
+                (input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
+            )

-        if attention_mask is not None and combined_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = combined_attention_mask + _expand_mask(
+            inputs["attention_mask"], tgt_len=input_shape[-1]
+        )

-        encoder_hidden_states = inputs["encoder_hidden_states"]
-        if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None:
+        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
+            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])

         if self.do_blenderbot_90_layernorm:
             hidden_states = self.layernorm_embedding(hidden_states) + positions
@@ -932,8 +936,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
             hidden_states, layer_self_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
+                encoder_hidden_states=inputs["encoder_hidden_states"],
+                encoder_attention_mask=inputs["encoder_attention_mask"],
                 past_key_value=past_key_value,
             )
@@ -954,7 +958,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
         all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
-        present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
+        present_key_values = (inputs["encoder_hidden_states"], present_key_values) if inputs["use_cache"] else None

         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns
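
The net effect of the decoder changes above is that combined_attention_mask and inputs["attention_mask"] are concrete tensors on every code path, which is what graph tracing for a saved model needs. As a standalone illustration of the causal-mask idea (my own approximation, not the actual _make_causal_mask helper):

import tensorflow as tf

def causal_mask_sketch(bsz, tgt_len, past_key_values_length=0):
    # Query position i may attend to key positions j <= i + past_key_values_length.
    i = tf.range(tgt_len)[:, None]
    j = tf.range(tgt_len + past_key_values_length)[None, :]
    allowed = tf.cast(j <= i + past_key_values_length, tf.float32)
    bias = (1.0 - allowed) * -1e9  # 0 where attention is allowed, very negative where it is not
    # Shape [bsz, 1, tgt_len, tgt_len + past_key_values_length], ready to add to attention scores.
    return tf.tile(bias[None, None, :, :], [bsz, 1, 1, 1])

print(causal_mask_sketch(1, 3)[0, 0])  # lower-triangular pattern of zeros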