Unverified commit 9fc34235, authored by amyeroberts, committed by GitHub

Use shape_list to safely get shapes for Swin (#17591)

* Use shape_list to safely get shapes

* Add relevant test

* Tidy and add metrics

* Resolve dynamic shaping issues and move test

* Tidy up and all samples in batch

* Formatting
parent e0be053e
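The core of the change is swapping static `.shape` accesses for `shape_list`. As a rough illustration of why (not part of the diff, and the sizes below are invented): `tensor.shape` returns `None` for dimensions only known at run time, which breaks `tf.reshape` in graph mode, whereas `shape_list` substitutes a dynamic `tf.shape` entry for each unknown dimension. A minimal sketch, assuming the helper behaves like the simplified version here (the real one lives in the transformers TF utilities):

```python
import tensorflow as tf

def shape_list(tensor):
    # Static Python ints where the dimension is known, dynamic tf.shape() entries where it is not
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 7, 64], tf.float32)])
def flatten_static(x):
    batch_size, seq_len, hidden = x.shape  # batch_size is None when tracing
    return tf.reshape(x, (batch_size, seq_len * hidden))  # fails: None is not a valid dimension

@tf.function(input_signature=[tf.TensorSpec([None, 7, 64], tf.float32)])
def flatten_dynamic(x):
    batch_size, seq_len, hidden = shape_list(x)  # batch_size is a scalar tensor, the rest are ints
    return tf.reshape(x, (batch_size, seq_len * hidden))  # fine: static shape (None, 448)
```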
@@ -648,7 +648,12 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
         context_layer = tf.matmul(attention_probs, value_layer)
         context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
-        new_context_layer_shape = shape_list(context_layer)[:-2] + [-1]
+        context_layer_shape = shape_list(context_layer)
+        # Set the final dimension here explicitly.
+        # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
+        # the model in graph mode as context_layer is reshaped to (None, 7, None) and the Dense layer in
+        # TFDebertaSelfOutput requires the final input dimension to be defined
+        new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
         context_layer = tf.reshape(context_layer, new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         return outputs
......
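For context on the comment in the hunk above, a worked example with invented sizes (12 heads of size 64, sequence length 7), reusing the `shape_list` sketch from earlier: keeping the statically known head dimensions as Python ints and multiplying them preserves a defined final dimension for the downstream Dense layer, where `-1` would not.

```python
import tensorflow as tf

# Invented sizes: (batch, seq_len, num_heads, head_size); batch may be dynamic at trace time
context_layer = tf.zeros((3, 7, 12, 64))
context_layer_shape = shape_list(context_layer)       # [3, 7, 12, 64]; batch is a tensor if dynamic
new_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
context_layer = tf.reshape(context_layer, new_shape)  # (3, 7, 768): final dim is a static 768
# Reshaping with -1 instead leaves the static shape as (None, 7, None) under a dynamic batch,
# and the following Dense layer cannot build without a defined final dimension.
```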
@@ -620,11 +620,15 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
         self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")

     def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
-        shape = shape_list(tensor)[:-1] + [attention_heads, -1]
+        tensor_shape = shape_list(tensor)
+        # In graph mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
+        shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
         # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
         tensor = tf.reshape(tensor=tensor, shape=shape)
+        tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
         x_shape = shape_list(tensor)
-        return tf.reshape(tf.transpose(tensor, perm=[0, 2, 1, 3]), shape=[-1, x_shape[1], x_shape[-1]])
+        tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
+        return tensor

     def call(
         self,
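The same idea applies in `transpose_for_scores` above: dividing the statically known hidden size by the head count gives a static head size, instead of a trailing `-1` that graph mode cannot resolve when the batch dimension is `None`. A toy restatement with invented sizes, reusing the `shape_list` sketch from earlier:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 7, 768], tf.float32)])
def transpose_for_scores_sketch(tensor):
    attention_heads = 12  # invented head count for the sketch
    tensor_shape = shape_list(tensor)                                 # [<scalar tensor>, 7, 768]
    shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
    tensor = tf.reshape(tensor, shape)                                # (None, 7, 12, 64), last dims static
    tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])                  # (None, 12, 7, 64)
    x_shape = shape_list(tensor)
    return tf.reshape(tensor, [-1, x_shape[-2], x_shape[-1]])         # (None * 12, 7, 64)
```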
@@ -686,7 +690,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
         if rel_att is not None:
             attention_scores = attention_scores + rel_att
-        attention_scores = attention_scores
         attention_scores = tf.reshape(
             attention_scores,
             (-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
@@ -706,9 +709,12 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
             ),
             [0, 2, 1, 3],
         )
-        new_context_layer_shape = shape_list(context_layer)[:-2] + [
-            -1,
-        ]
+        # Set the final dimension here explicitly.
+        # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
+        # the model in graph mode as context_layer is reshaped to (None, 7, None) and the Dense layer in
+        # TFDebertaV2SelfOutput requires the final input dimension to be defined
+        context_layer_shape = shape_list(context_layer)
+        new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
         context_layer = tf.reshape(context_layer, new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         return outputs
......
@@ -213,7 +213,7 @@ def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
     """
     Partitions the given input into windows.
     """
-    batch_size, height, width, num_channels = input_feature.shape
+    batch_size, height, width, num_channels = shape_list(input_feature)
     input_feature = tf.reshape(
         input_feature,
         (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),
@@ -227,7 +227,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
    """
    Merges windows to produce higher resolution features.
    """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    x = shape_list(windows)[0]
+    y = tf.cast(height * width / window_size / window_size, tf.int32)
+    batch_size = int(x / y)
     windows = tf.reshape(
         windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     )
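On the batch-size arithmetic in `window_reverse`: partitioning a padded feature map into `window_size` x `window_size` windows yields `(height / window_size) * (width / window_size)` windows per image, so the original batch size is the total window count divided by that factor. A worked example with invented sizes:

```python
import tensorflow as tf

# Invented sizes: height = width = 56, window_size = 7, 96 channels
windows_per_image = (56 // 7) * (56 // 7)                # 8 * 8 = 64 windows per image
windows = tf.zeros((2 * windows_per_image, 7, 7, 96))    # a batch of 2 images after partitioning
batch_size = tf.shape(windows)[0] // windows_per_image   # recovers 2, even with dynamic shapes
print(int(batch_size))  # 2
```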
@@ -245,7 +247,9 @@ def drop_path(
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
-    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    input_shape = shape_list(input)
+    ndim = len(input_shape)
+    shape = [input_shape[0]] + [1] * (ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = tf.random.uniform(shape)
    random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
    if keep_prob > 0.0 and scale_by_keep:
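`drop_path` implements stochastic depth: each sample's residual branch is zeroed with probability `drop_prob`, and survivors are rescaled so the expected output is unchanged. A minimal restatement of the patched function, assuming the `shape_list` sketch from earlier and `scale_by_keep=True`:

```python
import tensorflow as tf

def drop_path_sketch(x: tf.Tensor, drop_prob: float, training: bool = False) -> tf.Tensor:
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    input_shape = shape_list(x)
    mask_shape = [input_shape[0]] + [1] * (len(input_shape) - 1)  # one mask entry per sample
    mask = tf.where(tf.random.uniform(mask_shape) <= keep_prob, 1.0, 0.0)
    return x * mask / keep_prob  # rescale survivors so E[output] == x
```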
@@ -295,7 +299,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
     ) -> Tuple[tf.Tensor, Tuple[int, int]]:
         embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
         embeddings = self.norm(embeddings, training=training)
-        batch_size, seq_len, _ = embeddings.shape
+        batch_size, seq_len, _ = shape_list(embeddings)

         if bool_masked_pos is not None:
             mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
@@ -357,10 +361,10 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
         # B,H,W,C -> B,C,H,W
         embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
-        _, _, height, width = embeddings.shape
+        batch_size, channels, height, width = shape_list(embeddings)
         output_dimensions = (height, width)
-        embeddings = tf.reshape(embeddings, (embeddings.shape[0], embeddings.shape[1], -1))
+        embeddings = tf.reshape(embeddings, (batch_size, channels, -1))
         embeddings = tf.transpose(embeddings, (0, 2, 1))
         return embeddings, output_dimensions
@@ -402,7 +406,7 @@ class TFSwinPatchMerging(tf.keras.layers.Layer):
     def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
         height, width = input_dimensions
         # `dim` is height * width
-        batch_size, _, num_channels = input_feature.shape
+        batch_size, _, num_channels = shape_list(input_feature)
         input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
         # pad input to be divisible by width and height, if needed
@@ -456,7 +460,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
         coords_h = tf.range(self.window_size[0])
         coords_w = tf.range(self.window_size[1])
         coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
-        coords_flatten = tf.reshape(coords, (coords.shape[0], -1))
+        coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = tf.transpose(relative_coords, (1, 2, 0))
@@ -497,7 +501,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
         super().build(input_shape)

     def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
-        new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
+        new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
         x = tf.reshape(x, new_x_shape)
         return tf.transpose(x, (0, 2, 1, 3))
@@ -509,7 +513,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
         output_attentions: bool = False,
         training: bool = False,
     ) -> Tuple[tf.Tensor, ...]:
-        batch_size, dim, _ = hidden_states.shape
+        batch_size, dim, _ = shape_list(hidden_states)
         mixed_query_layer = self.query(hidden_states)

         key_layer = self.transpose_for_scores(self.key(hidden_states))
@@ -533,7 +537,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
         if attention_mask is not None:
             # Apply the attention mask (precomputed for all layers in SwinModel forward() function)
-            mask_shape = attention_mask.shape[0]
+            mask_shape = shape_list(attention_mask)[0]
             attention_scores = tf.reshape(
                 attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
             )
@@ -555,7 +559,9 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
         context_layer = tf.matmul(attention_probs, value_layer)
         context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
+        new_context_layer_shape = shape_list(context_layer)[:-2] + [
+            self.all_head_size,
+        ]
         context_layer = tf.reshape(context_layer, new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
@@ -720,7 +726,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
     ) -> tf.Tensor:
         self.set_shift_and_window_size(input_dimensions)
         height, width = input_dimensions
-        batch_size, _, channels = hidden_states.shape
+        batch_size, _, channels = shape_list(hidden_states)
         shortcut = hidden_states

         hidden_states = self.layernorm_before(hidden_states, training=training)
@@ -728,7 +734,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
         # pad hidden_states to multiples of window size
         hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
-        _, height_pad, width_pad, _ = hidden_states.shape
+        _, height_pad, width_pad, _ = shape_list(hidden_states)
         # cyclic shift
         if self.shift_size > 0:
             shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
@@ -881,7 +887,7 @@ class TFSwinEncoder(tf.keras.layers.Layer):
         all_self_attentions = () if output_attentions else None

         if output_hidden_states:
-            batch_size, _, hidden_size = hidden_states.shape
+            batch_size, _, hidden_size = shape_list(hidden_states)
             # rearrange b (h w) c -> b c h w
             reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
             reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
@@ -902,7 +908,7 @@ class TFSwinEncoder(tf.keras.layers.Layer):
             all_input_dimensions += (input_dimensions,)

             if output_hidden_states:
-                batch_size, _, hidden_size = hidden_states.shape
+                batch_size, _, hidden_size = shape_list(hidden_states)
                 # rearrange b (h w) c -> b c h w
                 reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
                 reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
@@ -1152,7 +1158,7 @@ class TFSwinModel(TFSwinPreTrainedModel):
         pooled_output = None
         if self.pooler is not None:
-            batch_size, _, num_features = sequence_output.shape
+            batch_size, _, num_features = shape_list(sequence_output)
             pooled_output = self.pooler(sequence_output)
             pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
@@ -1206,7 +1212,7 @@ class TFSwinDecoder(tf.keras.layers.Layer):
         # B,C,H,W -> B,H,W,C
         hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
         hidden_states = self.conv2d(hidden_states)
-        batch_size, _, _, num_input_channels = hidden_states.shape
+        batch_size, _, _, num_input_channels = shape_list(hidden_states)
         block_size_squared = self._block_size**2
         output_depth = int(num_input_channels / block_size_squared)
         # When the number of output channels >= 2, PyTorch's PixelShuffle and
@@ -1293,7 +1299,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
         sequence_output = outputs[0]
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = tf.transpose(sequence_output, (0, 2, 1))
-        batch_size, num_channels, sequence_length = sequence_output.shape
+        batch_size, num_channels, sequence_length = shape_list(sequence_output)
         height = width = int(sequence_length**0.5)
         sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width))
......
@@ -1406,6 +1406,24 @@ class TFModelTesterMixin:
         if metrics:
             self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")

+        # Make sure fit works with tf.data.Dataset and results are consistent
+        dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
+        # Pass in all samples as a batch to match other `fit` calls
+        dataset = dataset.batch(len(dataset))
+        history3 = model.fit(
+            dataset,
+            validation_data=dataset,
+            steps_per_epoch=1,
+            validation_steps=1,
+            shuffle=False,
+        )
+        val_loss3 = history3.history["val_loss"][0]
+        accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
+        self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
+        self.assertEqual(history1.history.keys(), history3.history.keys())
+        if metrics:
+            self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
+
     def test_int64_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
......
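A note on the test addition: `tf.data.Dataset.from_tensor_slices` accepts a dict of tensors and yields one dict per sample, and `len(dataset)` works because the cardinality is known up front, so `dataset.batch(len(dataset))` produces a single batch containing every sample, matching the earlier `fit` calls on raw tensors. A toy example with an invented stand-in for `prepared_for_class`:

```python
import tensorflow as tf

# Hypothetical stand-in for `prepared_for_class`: a dict of per-sample tensors
features = {"input_ids": tf.ones((4, 8), dtype=tf.int32), "labels": tf.zeros((4,), dtype=tf.int32)}
dataset = tf.data.Dataset.from_tensor_slices(features)  # 4 elements, each a dict for one sample
dataset = dataset.batch(len(dataset))                   # a single batch holding all 4 samples
for batch in dataset:
    print({k: v.shape for k, v in batch.items()})       # {'input_ids': (4, 8), 'labels': (4,)}
```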