Unverified Commit 4a55e478 authored by Matt, committed by GitHub

Move TF building to an actual build() method (#23760)

* A fun new PR where I break the entire codebase again

* A fun new PR where I break the entire codebase again

* Handle cross-attention

* Move calls to model(model.dummy_inputs) to the new build() method

* Seeing what fails with the build context thing

* make fix-copies

* Let's see what fails with new build methods

* Fix the pytorch crossload build calls

* Fix the overridden build methods in vision_text_dual_encoder

* Make sure all our build methods set self.built or call super().build(), which also sets it

* make fix-copies

* Remove finished TODO

* Tentatively remove unneeded (?) line

* Transpose b in deberta correctly and remove unused threading local

* Get rid of build_with_dummies and all it stands for

* Rollback some changes to TF-PT crossloading

* Correctly call super().build()
parent cbf6bc23
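
The core of the change is visible in the first hunks below: TFPreTrainedModel gains a real build() override that runs the model once on its dummy_inputs, and code that previously wrote model(model.dummy_inputs) now calls model.build(). A minimal sketch of the pattern, with invented names (TFToyModel and its toy dummy_inputs are not from the Transformers codebase) and the Keras call_context() guard simplified to the built flag:

```python
import tensorflow as tf


class TFToyModel(tf.keras.Model):
    """Illustrative stand-in for TFPreTrainedModel; names are simplified."""

    def __init__(self, vocab_size=100, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size)
        self.dense = tf.keras.layers.Dense(hidden_size)

    @property
    def dummy_inputs(self):
        # Hypothetical dummy batch; the real models derive this from their config.
        return {"input_ids": tf.constant([[7, 6, 0, 0, 1]], dtype=tf.int32)}

    def build(self, input_shape=None):
        # The real method additionally consults Keras's internal call_context()
        # so that a build() triggered from inside __call__ does not recurse;
        # here the `built` flag itself serves as that guard.
        if self.built:
            return
        self.built = True  # set first so the dummy call below cannot re-enter build()
        self(self.dummy_inputs, training=False)

    def call(self, inputs, training=False):
        return self.dense(self.embeddings(inputs["input_ids"]))


model = TFToyModel()
model.build()              # creates every weight via one dummy forward pass
print(len(model.weights))  # weights now exist without a real input batch
```

In the real method, the call_context().in_call check covers the case where Keras itself invokes build() from inside __call__ during that dummy forward pass: the model is simply marked as built instead of starting another pass.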
@@ -341,9 +341,6 @@ def load_pytorch_state_dict_in_tf2_model(
     K.batch_set_value(weight_value_tuples)

-    if tf_inputs is not None:
-        tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-
     logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

     unexpected_keys = list(all_pytorch_weights)
...
@@ -40,7 +40,12 @@ from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import GenerationConfig, TFGenerationMixin
-from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
+from .tf_utils import (
+    expand_1d,
+    load_attributes_from_hdf5_group,
+    save_attributes_to_hdf5_group,
+    shape_list,
+)
 from .utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -69,11 +74,14 @@ from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
 if parse(tf.__version__).minor >= 13:
     from keras import backend as K
     from keras.__internal__ import KerasTensor
+    from keras.engine.base_layer_utils import call_context
 elif parse(tf.__version__).minor >= 11:
     from keras import backend as K
+    from keras.engine.base_layer_utils import call_context
     from keras.engine.keras_tensor import KerasTensor
 else:
     from tensorflow.python.keras import backend as K
+    from tensorflow.python.keras.engine import call_context
     from tensorflow.python.keras.engine.keras_tensor import KerasTensor
@@ -1140,6 +1148,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         return "tf"

+    def build(self, input_shape=None):
+        if self.built or call_context().in_call:
+            self.built = True
+        else:
+            self(self.dummy_inputs, training=False)
+            self.built = True
+
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
         if not isinstance(config, PretrainedConfig):
@@ -1867,7 +1882,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             main_layer.set_input_embeddings(value)
         except AttributeError:
             logger.info("Building the model")
-            self(self.dummy_inputs)
+            self.build()
             main_layer.set_input_embeddings(value)

     def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1884,7 +1899,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 return lm_head.get_output_embeddings()
             except AttributeError:
                 logger.info("Building the model")
-                self(self.dummy_inputs)
+                self.build()

                 return lm_head().get_output_embeddings()
@@ -1904,7 +1919,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 lm_head.set_output_embeddings(value)
             except AttributeError:
                 logger.info("Building the model")
-                self(self.dummy_inputs)
+                self.build()
                 lm_head.set_output_embeddings(value)

     def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1942,7 +1957,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             try:
                 return lm_head.get_bias()
             except AttributeError:
-                self(self.dummy_inputs)
+                self.build()
                 return lm_head.get_bias()
         return None
@@ -1960,7 +1975,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             try:
                 lm_head.set_bias(value)
             except AttributeError:
-                self(self.dummy_inputs)
+                self.build()
                 lm_head.set_bias(value)

     def get_lm_head(self) -> tf.keras.layers.Layer:
@@ -2047,7 +2062,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # The reason why the attributes don't exist might be
         # because the model is not built, so retry getting
         # the argument after building the model
-        model(model.dummy_inputs)
+        model.build()

         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
@@ -2870,9 +2885,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # we might need to extend the variable scope for composite models
         if load_weight_prefix is not None:
             with tf.compat.v1.variable_scope(load_weight_prefix):
-                model(model.dummy_inputs)  # build the network with dummy inputs
+                model.build()  # build the network with dummy inputs
         else:
-            model(model.dummy_inputs)  # build the network with dummy inputs
+            model.build()  # build the network with dummy inputs

         if safetensors_from_pt:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
@@ -2925,8 +2940,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                     "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
                 )

-        model(model.dummy_inputs)  # Make sure restore ops are run
-
         if cls._keys_to_ignore_on_load_missing is not None:
             for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
...
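
The hunks above swap self(self.dummy_inputs) for self.build() in the embedding and bias accessors and in from_pretrained. From the user's side the effect is unchanged: calling build() forces weight creation without a real batch. A short hedged usage sketch (TFBertModel is only an illustrative model class):

```python
from transformers import BertConfig, TFBertModel

# Small illustrative config; any TF model class follows the same pattern.
config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=2, intermediate_size=256)
model = TFBertModel(config)

# Previously this step was spelled `model(model.dummy_inputs)`; after this
# change the same weight-creation step is a plain Keras-style build() call.
model.build()

print(model.get_input_embeddings())  # the embedding layer now has weights
```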
@@ -258,6 +258,7 @@ class TFBlipVisionEmbeddings(tf.keras.layers.Layer):
             trainable=True,
             name="position_embedding",
         )
+        super().build(input_shape)

     def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
         # Input is channels-first, we transpose. PyTorch transposes after the conv because PyTorch
@@ -282,7 +283,7 @@ class TFBlipTextEmbeddings(tf.keras.layers.Layer):
         self.config = config

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         with tf.name_scope("token_embedding"):
             self.weight = self.add_weight(
                 shape=(self.config.vocab_size, self.embed_dim),
@@ -757,13 +758,14 @@ class TFBlipMainLayer(tf.keras.layers.Layer):
         self.config = config

-    def build(self, input_shape):
+    def build(self, input_shape=None):
         self.logit_scale = self.add_weight(
             name="logit_scale",
             shape=[],
             initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
             trainable=True,
         )
+        super().build(input_shape)

     @unpack_inputs
     def call(
...
@@ -543,8 +543,9 @@ class TFBlipTextLMPredictionHead(tf.keras.layers.Layer):
         )
         self.config = config

-    def build(self, input_shape):
+    def build(self, input_shape=None):
         self.bias = self.add_weight(name="bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
+        super().build(input_shape)

     def call(self, hidden_states):
         hidden_states = self.transform(hidden_states)
...
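
The per-model hunks in the rest of the diff apply one rule from the commit message: every overridden build() takes input_shape=None, creates the weights the layer owns, and then either sets self.built or calls super().build(), which sets it. A hedged sketch of that pattern with an invented layer (ToyBiasHead is not a Transformers class):

```python
import tensorflow as tf


class ToyBiasHead(tf.keras.layers.Layer):
    """Invented example layer; mirrors the pattern used by the BLIP/CTRL heads."""

    def __init__(self, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size

    def build(self, input_shape=None):
        # Create the weights this layer owns...
        self.bias = self.add_weight(
            name="bias", shape=(self.vocab_size,), initializer="zeros", trainable=True
        )
        # ...then let Keras mark the layer as built (sets self.built = True).
        super().build(input_shape)

    def call(self, hidden_states):
        return hidden_states + self.bias


layer = ToyBiasHead(vocab_size=8)
layer.build()                                      # safe to call before any forward pass
print(layer.built, layer(tf.zeros((2, 8))).shape)  # True (2, 8)
```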
@@ -151,7 +151,7 @@ class TFCLIPVisionEmbeddings(tf.keras.layers.Layer):
             name="patch_embedding",
         )

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         factor = self.config.initializer_factor

         self.class_embedding = self.add_weight(
@@ -204,7 +204,7 @@ class TFCLIPTextEmbeddings(tf.keras.layers.Layer):
         self.config = config

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         with tf.name_scope("token_embedding"):
             self.weight = self.add_weight(
                 shape=(self.config.vocab_size, self.embed_dim),
@@ -739,7 +739,7 @@ class TFCLIPMainLayer(tf.keras.layers.Layer):
             name="text_projection",
         )

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         self.logit_scale = self.add_weight(
             shape=(1,),
             initializer=tf.keras.initializers.Constant(self.config.logit_scale_init_value),
...
@@ -346,7 +346,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
         self.group_in_dim = self.input_size // self.num_groups
         self.group_out_dim = self.output_size // self.num_groups

-    def build(self, input_shape):
+    def build(self, input_shape=None):
         self.kernel = self.add_weight(
             "kernel",
             shape=[self.group_out_dim, self.group_in_dim, self.num_groups],
@@ -357,6 +357,7 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
         self.bias = self.add_weight(
             "bias", shape=[self.output_size], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True
         )
+        super().build(input_shape)

     def call(self, hidden_states):
         batch_size = shape_list(hidden_states)[0]
...
@@ -155,7 +155,7 @@ class TFConvNextLayer(tf.keras.layers.Layer):
             else tf.keras.layers.Activation("linear", name="drop_path")
         )

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
         self.layer_scale_parameter = (
             self.add_weight(
...
@@ -576,7 +576,7 @@ class TFCTRLLMHead(tf.keras.layers.Layer):
         # an output-only bias for each token.
         self.input_embeddings = input_embeddings

-    def build(self, input_shape):
+    def build(self, input_shape=None):
         self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
         super().build(input_shape)
...
@@ -464,7 +464,7 @@ class TFData2VecVisionLayer(tf.keras.layers.Layer):
         )
         self.init_values = config.layer_scale_init_value

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         if self.init_values > 0:
             self.lambda_1 = self.add_weight(
                 shape=(self.config.hidden_size),
...
@@ -593,11 +593,10 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
         else:

             def linear(w, b, x):
-                return tf.cond(
-                    b is not None,
-                    lambda: tf.matmul(x, w, transpose_b=True) + tf.transpose(b),
-                    lambda: tf.matmul(x, w, transpose_b=True),
-                )
+                out = tf.matmul(x, w, transpose_b=True)
+                if b is not None:
+                    out += tf.transpose(b)
+                return out

             ws = tf.split(
                 tf.transpose(self.in_proj.weight[0]), num_or_size_splits=self.num_attention_heads * 3, axis=0
...
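
The DeBERTa hunk above also drops a tf.cond whose predicate (b is not None) is ordinary Python rather than a tensor; tf.cond takes both branches as callables and, under graph tracing, builds both of them, so plain control flow is the right tool for a static bias check. A small self-contained sketch of the rewritten helper:

```python
import tensorflow as tf


def linear(w, b, x):
    """Project x with w and add the bias only if one was provided.

    The bias check is static Python information, so a plain `if` is used
    instead of tf.cond, which would require both branches as callables.
    """
    out = tf.matmul(x, w, transpose_b=True)
    if b is not None:
        out += tf.transpose(b)
    return out


x = tf.random.normal((2, 8))
w = tf.random.normal((4, 8))
print(linear(w, None, x).shape)            # (2, 4), no bias added
print(linear(w, tf.zeros((4,)), x).shape)  # (2, 4), bias added
```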
@@ -532,7 +532,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
         try:
             return self.ctx_encoder.bert_model.get_input_embeddings()
         except AttributeError:
-            self(self.dummy_inputs)
+            self.build()
             return self.ctx_encoder.bert_model.get_input_embeddings()

     @unpack_inputs
@@ -613,7 +613,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
         try:
             return self.question_encoder.bert_model.get_input_embeddings()
         except AttributeError:
-            self(self.dummy_inputs)
+            self.build()
             return self.question_encoder.bert_model.get_input_embeddings()

     @unpack_inputs
@@ -693,7 +693,7 @@ class TFDPRReader(TFDPRPretrainedReader):
         try:
             return self.span_predictor.encoder.bert_model.get_input_embeddings()
         except AttributeError:
-            self(self.dummy_inputs)
+            self.build()
             return self.span_predictor.encoder.bert_model.get_input_embeddings()

     @unpack_inputs
...
@@ -538,7 +538,7 @@ class TFGroupViTTextEmbeddings(tf.keras.layers.Layer):
         self.config = config

-    def build(self, input_shape: tf.TensorShape):
+    def build(self, input_shape: tf.TensorShape = None):
         with tf.name_scope("token_embedding"):
             self.weight = self.add_weight(
                 shape=(self.config.vocab_size, self.embed_dim),
...
@@ -135,6 +135,7 @@ class TFLEDLearnedPositionalEmbedding(tf.keras.layers.Embedding):
 class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, layer_id, **kwargs):
         super().__init__(**kwargs)
+        self.config = config

         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -191,6 +192,16 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         self.one_sided_attn_window_size = attention_window // 2

+    def build(self, input_shape=None):
+        if not self.built:
+            with tf.name_scope("query_global"):
+                self.query_global.build((self.config.hidden_size,))
+            with tf.name_scope("key_global"):
+                self.key_global.build((self.config.hidden_size,))
+            with tf.name_scope("value_global"):
+                self.value_global.build((self.config.hidden_size,))
+        super().build(input_shape)
+
     def call(
         self,
         inputs,
@@ -271,9 +282,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         ) = self._get_global_attn_indices(is_index_global_attn)

         # this function is only relevant for global attention
-        attn_scores = tf.cond(
-            is_global_attn,
-            lambda: self._concat_with_global_key_attn_probs(
+        if is_global_attn:
+            attn_scores = self._concat_with_global_key_attn_probs(
                 attn_scores=attn_scores,
                 query_vectors=query_vectors,
                 key_vectors=key_vectors,
@@ -281,25 +291,23 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
-            ),
-            lambda: attn_scores,
-        )
+            )

         attn_probs = stable_softmax(attn_scores, axis=-1)

         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         # Make sure to create a mask with the proper shape:
         # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
         # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        masked_index = tf.cond(
-            is_global_attn,
-            lambda: tf.tile(
+        if is_global_attn:
+            masked_index = tf.tile(
                 is_index_masked[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
-            ),
-            lambda: tf.tile(
+            )
+        else:
+            masked_index = tf.tile(
                 is_index_masked[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
-            ),
-        )
+            )
         attn_probs = tf.where(
             masked_index,
@@ -324,18 +332,18 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

         # if global attention, compute sum of global and local attn
-        attn_output = tf.cond(
-            is_global_attn,
-            lambda: self._compute_attn_output_with_global_indices(
+        if is_global_attn:
+            attn_output = self._compute_attn_output_with_global_indices(
                 value_vectors=value_vectors,
                 attn_probs=attn_probs,
                 max_num_global_attn_indices=max_num_global_attn_indices,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
-            ),
-            lambda: self._sliding_chunks_matmul_attn_probs_value(
+            )
+        else:
+            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                 attn_probs, value_vectors, self.one_sided_attn_window_size
-            ),
-        )
+            )

         tf.debugging.assert_equal(
@@ -345,10 +353,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
         attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

         # compute value for global attention and overwrite to attention output
-        # TODO: remove the redundant computation
-        attn_output, global_attn_probs = tf.cond(
-            is_global_attn,
-            lambda: self._compute_global_attn_output_from_hidden(
+        if is_global_attn:
+            attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                 attn_output=attn_output,
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
@@ -358,24 +364,24 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                 is_index_masked=is_index_masked,
                 training=training,
-            ),
-            lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
-        )
+            )
+        else:
+            # Leave attn_output unchanged
+            global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))

         # make sure that local attention probabilities are set to 0 for indices of global attn
         # Make sure to create a mask with the proper shape:
         # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
         # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        masked_global_attn_index = tf.cond(
-            is_global_attn,
-            lambda: tf.tile(
+        if is_global_attn:
+            masked_global_attn_index = tf.tile(
                 is_index_global_attn[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
-            ),
-            lambda: tf.tile(
+            )
+        else:
+            masked_global_attn_index = tf.tile(
                 is_index_global_attn[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
-            ),
-        )
+            )
         attn_probs = tf.where(
             masked_global_attn_index,
@@ -1864,13 +1870,10 @@ class TFLEDEncoder(tf.keras.layers.Layer):
             input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)

         if inputs_embeds is not None:
-
-            def pad_embeddings():
+            if padding_len > 0:
                 input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
                 inputs_embeds_padding = self.embed_tokens(input_ids_padding)
-                return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
-
-            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
+                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

         attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
...
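
The new TFLEDEncoderSelfAttention.build() above (and its Longformer counterpart below) explicitly builds the query_global/key_global/value_global Dense sublayers inside tf.name_scope blocks, so their weight names keep the hierarchy expected by existing checkpoints even though they are no longer created during a dummy forward pass. A hedged, self-contained sketch of that pattern with an invented layer:

```python
import tensorflow as tf


class ToyGlobalAttention(tf.keras.layers.Layer):
    """Invented layer showing the explicit-sublayer-build pattern from the diff."""

    def __init__(self, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.query_global = tf.keras.layers.Dense(hidden_size, name="query_global")

    def build(self, input_shape=None):
        if not self.built:
            # Build the sublayer under an explicit name scope so its variables
            # carry a "query_global/..." prefix even though no forward pass has
            # happened yet.
            with tf.name_scope("query_global"):
                self.query_global.build((self.hidden_size,))
        super().build(input_shape)

    def call(self, hidden_states):
        return self.query_global(hidden_states)


layer = ToyGlobalAttention(hidden_size=16)
layer.build()
print([w.name for w in layer.weights])  # e.g. ['query_global/kernel:0', 'query_global/bias:0']
```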
@@ -652,6 +652,7 @@ class TFLongformerSelfOutput(tf.keras.layers.Layer):
 class TFLongformerSelfAttention(tf.keras.layers.Layer):
     def __init__(self, config, layer_id, **kwargs):
         super().__init__(**kwargs)
+        self.config = config

         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -708,6 +709,16 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         self.one_sided_attn_window_size = attention_window // 2

+    def build(self, input_shape=None):
+        if not self.built:
+            with tf.name_scope("query_global"):
+                self.query_global.build((self.config.hidden_size,))
+            with tf.name_scope("key_global"):
+                self.key_global.build((self.config.hidden_size,))
+            with tf.name_scope("value_global"):
+                self.value_global.build((self.config.hidden_size,))
+        super().build(input_shape)
+
     def call(
         self,
         inputs,
@@ -788,9 +799,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         ) = self._get_global_attn_indices(is_index_global_attn)

         # this function is only relevant for global attention
-        attn_scores = tf.cond(
-            is_global_attn,
-            lambda: self._concat_with_global_key_attn_probs(
+        if is_global_attn:
+            attn_scores = self._concat_with_global_key_attn_probs(
                 attn_scores=attn_scores,
                 query_vectors=query_vectors,
                 key_vectors=key_vectors,
@@ -798,25 +808,23 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
-            ),
-            lambda: attn_scores,
-        )
+            )

         attn_probs = stable_softmax(attn_scores, axis=-1)

         # softmax sometimes inserts NaN if all positions are masked, replace them with 0
         # Make sure to create a mask with the proper shape:
         # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
         # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        masked_index = tf.cond(
-            is_global_attn,
-            lambda: tf.tile(
+        if is_global_attn:
+            masked_index = tf.tile(
                 is_index_masked[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
-            ),
-            lambda: tf.tile(
+            )
+        else:
+            masked_index = tf.tile(
                 is_index_masked[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
-            ),
-        )
+            )
         attn_probs = tf.where(
             masked_index,
@@ -841,18 +849,18 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

         # if global attention, compute sum of global and local attn
-        attn_output = tf.cond(
-            is_global_attn,
-            lambda: self._compute_attn_output_with_global_indices(
+        if is_global_attn:
+            attn_output = self._compute_attn_output_with_global_indices(
                 value_vectors=value_vectors,
                 attn_probs=attn_probs,
                 max_num_global_attn_indices=max_num_global_attn_indices,
                 is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                 is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
-            ),
-            lambda: self._sliding_chunks_matmul_attn_probs_value(
+            )
+        else:
+            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                 attn_probs, value_vectors, self.one_sided_attn_window_size
-            ),
-        )
+            )

         tf.debugging.assert_equal(
@@ -862,10 +870,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
         attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))

         # compute value for global attention and overwrite to attention output
-        # TODO: remove the redundant computation
-        attn_output, global_attn_probs = tf.cond(
-            is_global_attn,
-            lambda: self._compute_global_attn_output_from_hidden(
+        if is_global_attn:
+            attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                 attn_output=attn_output,
                 hidden_states=hidden_states,
                 max_num_global_attn_indices=max_num_global_attn_indices,
@@ -875,24 +881,24 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
                 is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                 is_index_masked=is_index_masked,
                 training=training,
-            ),
-            lambda: (attn_output, tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))),
-        )
+            )
+        else:
+            # Leave attn_output unchanged
+            global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))

         # make sure that local attention probabilities are set to 0 for indices of global attn
         # Make sure to create a mask with the proper shape:
         # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
         # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
-        masked_global_attn_index = tf.cond(
-            is_global_attn,
-            lambda: tf.tile(
+        if is_global_attn:
+            masked_global_attn_index = tf.tile(
                 is_index_global_attn[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
-            ),
-            lambda: tf.tile(
+            )
+        else:
+            masked_global_attn_index = tf.tile(
                 is_index_global_attn[:, :, None, None],
                 (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
-            ),
-        )
+            )
         attn_probs = tf.where(
             masked_global_attn_index,
@@ -1828,13 +1834,10 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
             position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)

         if inputs_embeds is not None:
-
-            def pad_embeddings():
+            if padding_len > 0:
                 input_ids_padding = tf.cast(tf.fill((batch_size, padding_len), self.pad_token_id), tf.int64)
                 inputs_embeds_padding = self.embeddings(input_ids_padding)
-                return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
-
-            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
+                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)

         attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
         token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0
...
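
Throughout LED and Longformer, branches previously expressed with tf.cond (keyed on is_global_attn, padding_len and similar) become ordinary Python conditionals; that is fine when the predicate can be evaluated eagerly (a Python bool/int or a scalar eager tensor), because only the branch that actually runs is executed. A generic hedged illustration with toy tensors, not Longformer code:

```python
import tensorflow as tf

x = tf.random.normal((2, 5, 8))
attention_window = 4
seq_len = x.shape[1]  # a plain Python int in eager mode
padding_len = (attention_window - seq_len % attention_window) % attention_window

# Old style: the condition goes through tf.cond, which requires both branches
# as callables and traces both of them when running inside tf.function.
padded_old = tf.cond(
    tf.math.greater(padding_len, 0),
    lambda: tf.pad(x, [[0, 0], [0, padding_len], [0, 0]]),
    lambda: x,
)

# New style: an ordinary conditional only runs the branch that is needed.
if padding_len > 0:
    padded_new = tf.pad(x, [[0, 0], [0, padding_len], [0, 0]])
else:
    padded_new = x

print(padded_old.shape, padded_new.shape)  # both (2, 8, 8)
```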
@@ -151,6 +151,7 @@ class TFNoNorm(tf.keras.layers.Layer):
     def build(self, input_shape):
         self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros")
         self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones")
+        super().build(input_shape)

     def call(self, inputs: tf.Tensor):
         return inputs * self.weight + self.bias
...
@@ -581,6 +581,7 @@ class TFSamPositionalEmbedding(tf.keras.layers.Layer):
             initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
             trainable=False,
         )
+        super().build(input_shape)

     def call(self, input_coords, input_shape=None):
         """Positionally encode points that are normalized to [0,1]."""
...
@@ -225,6 +225,7 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
         # Build in the build() method to make sure the names are right
         initializer = tf.keras.initializers.Constant(self.config.logit_scale_init_value)
         self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
+        super().build(input_shape)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -591,7 +592,7 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
         if text_model.name != "text_model":
             raise ValueError("text model must be created with the name `text_model`.")

-        model(model.dummy_inputs)  # Ensure model is fully built
+        model.build()  # Ensure model is fully built

         return model
...
@@ -966,11 +966,8 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
         """
         patch_size, num_channels = self.config.patch_size, self.config.num_channels
         # make sure channels are last
-        pixel_values = tf.cond(
-            tf.math.equal(shape_list(pixel_values)[1], num_channels),
-            lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
-            lambda: pixel_values,
-        )
+        if shape_list(pixel_values)[1] == num_channels:
+            pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

         # sanity checks
         tf.debugging.assert_equal(
...
@@ -766,10 +766,11 @@ class TFWhisperDecoder(tf.keras.layers.Layer):
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         batch_size, seq_len = input_shape[0], input_shape[1]

-        combined_attention_mask = tf.cond(
-            tf.math.greater(seq_len, 1),
-            lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length),
-            lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len),
-        )
+        if seq_len > 1:
+            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        else:
+            combined_attention_mask = _expand_mask(
+                tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len
+            )

         if attention_mask is not None:
...
@@ -476,6 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.mask_emb = self.add_weight(
             shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
         )
+        super().build(input_shape)

     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
...