Commit cb36903f authored by Hao Wu, committed by A. Unique TensorFlower

Internal changes.

PiperOrigin-RevId: 409201490
parent e888406e
@@ -32,9 +32,7 @@ def _pad_strides(strides: int, axis: int) -> Tuple[int, int, int, int]:
   return (1, strides, strides, 1)
 
 
-def _maybe_downsample(x: tf.Tensor,
-                      out_filter: int,
-                      strides: int,
+def _maybe_downsample(x: tf.Tensor, out_filter: int, strides: int,
                       axis: int) -> tf.Tensor:
   """Downsamples feature map and 0-pads tensor if in_filter != out_filter."""
   data_format = 'NCHW' if axis == 1 else 'NHWC'
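The body of `_maybe_downsample` is collapsed in this hunk; only the docstring and the `data_format` line are visible. As a rough illustration of what the docstring describes, here is a minimal sketch (the name `maybe_downsample_sketch` and the choice of average pooling plus symmetric channel padding are assumptions, not necessarily the file's actual implementation):

```python
import tensorflow as tf


def maybe_downsample_sketch(x: tf.Tensor, out_filter: int, strides: int,
                            axis: int) -> tf.Tensor:
  """Sketch: avg-pool spatially, then 0-pad channels up to out_filter."""
  data_format = 'NCHW' if axis == 1 else 'NHWC'
  # Mirror _pad_strides from the hunk header: a 4-tuple with the stride on
  # the two spatial dimensions and 1 elsewhere.
  window = (1, 1, strides, strides) if axis == 1 else (1, strides, strides, 1)
  x = tf.nn.avg_pool2d(x, ksize=window, strides=window, padding='SAME',
                       data_format=data_format)
  in_filter = x.shape[axis]
  if in_filter < out_filter:
    # Zero-pad the channel axis so a residual addition's shapes match.
    pad = [[0, 0], [0, 0], [0, 0], [0, 0]]
    pad[axis] = [(out_filter - in_filter) // 2,
                 out_filter - in_filter - (out_filter - in_filter) // 2]
    x = tf.pad(x, pad)
  return x


# Halve 32x32 spatial dims and widen 16 -> 32 channels (NHWC).
y = maybe_downsample_sketch(tf.zeros([8, 32, 32, 16]), 32, 2, axis=3)
assert y.shape.as_list() == [8, 16, 16, 32]
```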
@@ -738,8 +736,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
     x = self._conv2(x)
     x = self._norm2(x)
 
-    if (self._use_residual and
-        self._in_filters == self._out_filters and
+    if (self._use_residual and self._in_filters == self._out_filters and
         self._strides == 1):
       if self._stochastic_depth:
         x = self._stochastic_depth(x, training=training)
@@ -859,8 +856,9 @@ class ResidualInner(tf.keras.layers.Layer):
     base_config = super(ResidualInner, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(
-      self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:
     x = inputs
     if self._batch_norm_first:
       x = self._batch_norm_0(x, training=training)
@@ -993,8 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
     base_config = super(BottleneckResidualInner, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(
-      self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:
     x = inputs
     if self._batch_norm_first:
       x = self._batch_norm_0(x, training=training)
@@ -1063,20 +1062,23 @@ class ReversibleLayer(tf.keras.layers.Layer):
 
   def _ckpt_non_trainable_vars(self):
     self._f_non_trainable_vars = [
-        v.read_value() for v in self._f.non_trainable_variables]
+        v.read_value() for v in self._f.non_trainable_variables
+    ]
     self._g_non_trainable_vars = [
-        v.read_value() for v in self._g.non_trainable_variables]
+        v.read_value() for v in self._g.non_trainable_variables
+    ]
 
   def _load_ckpt_non_trainable_vars(self):
-    for v, v_chkpt in zip(
-        self._f.non_trainable_variables, self._f_non_trainable_vars):
+    for v, v_chkpt in zip(self._f.non_trainable_variables,
+                          self._f_non_trainable_vars):
       v.assign(v_chkpt)
-    for v, v_chkpt in zip(
-        self._g.non_trainable_variables, self._g_non_trainable_vars):
+    for v, v_chkpt in zip(self._g.non_trainable_variables,
+                          self._g_non_trainable_vars):
       v.assign(v_chkpt)
 
-  def call(
-      self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:
 
     @tf.custom_gradient
     def reversible(
@@ -1101,12 +1103,12 @@ class ReversibleLayer(tf.keras.layers.Layer):
         fwdtape.watch(x)
         x1, x2 = tf.split(x, num_or_size_splits=2, axis=self._axis)
         f_x2 = self._f(x2, training=training)
-        x1_down = _maybe_downsample(
-            x1, f_x2.shape[self._axis], self._f.strides, self._axis)
+        x1_down = _maybe_downsample(x1, f_x2.shape[self._axis], self._f.strides,
+                                    self._axis)
         z1 = f_x2 + x1_down
         g_z1 = self._g(z1, training=training)
-        x2_down = _maybe_downsample(
-            x2, g_z1.shape[self._axis], self._f.strides, self._axis)
+        x2_down = _maybe_downsample(x2, g_z1.shape[self._axis], self._f.strides,
+                                    self._axis)
         y2 = x2_down + g_z1
 
         # Equation 8: https://arxiv.org/pdf/1707.04585.pdf
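For context on the two `_maybe_downsample` calls above: per Equation 8 of the paper linked in the code comment (Gomez et al., "The Reversible Residual Network", arXiv:1707.04585), the forward pass couples the two channel halves as

$$z_1 = f(x_2) + x_1, \qquad y_2 = g(z_1) + x_2, \qquad y_1 = z_1,$$

where $x_1, x_2$ are the halves of the input split along the channel axis, replaced by their downsampled versions when the strides or filter counts change.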
@@ -1114,17 +1116,17 @@ class ReversibleLayer(tf.keras.layers.Layer):
         y1 = tf.identity(z1)
         y = tf.concat([y1, y2], axis=self._axis)
 
-        irreversible = (
-            (self._f.strides != 1 or self._g.strides != 1)
-            or (y.shape[self._axis] != inputs.shape[self._axis]))
+        irreversible = ((self._f.strides != 1 or self._g.strides != 1) or
+                        (y.shape[self._axis] != inputs.shape[self._axis]))
 
         # Checkpointing moving mean/variance for batch normalization layers
         # as they shouldn't be updated during the custom gradient pass of f/g.
         self._ckpt_non_trainable_vars()
 
-      def grad_fn(dy: tf.Tensor,
-                  variables: Optional[List[tf.Variable]] = None,
-                  ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
+      def grad_fn(
+          dy: tf.Tensor,
+          variables: Optional[List[tf.Variable]] = None,
+      ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
         """Given dy calculate (dy/dx)|_{x_{input}} using f/g."""
         if irreversible or not self._manual_grads:
           grads_combined = fwdtape.gradient(
@@ -1158,16 +1160,12 @@ class ReversibleLayer(tf.keras.layers.Layer):
 
           # Compute gradients
           g_grads_combined = gtape.gradient(
-              g_z1,
-              [z1] + self._g.trainable_variables,
-              output_gradients=dy2)
+              g_z1, [z1] + self._g.trainable_variables, output_gradients=dy2)
           dz1 = dy1 + g_grads_combined[0]  # line 5
           dwg = g_grads_combined[1:]  # line 9
 
           f_grads_combined = ftape.gradient(
-              f_x2,
-              [x2] + self._f.trainable_variables,
-              output_gradients=dz1)
+              f_x2, [x2] + self._f.trainable_variables, output_gradients=dz1)
           dx2 = dy2 + f_grads_combined[0]  # line 6
           dwf = f_grads_combined[1:]  # line 8
           dx1 = dz1  # line 7
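The `# line 5` through `# line 9` comments track the backward-pass algorithm of the same RevNet paper. Written out, the two `gradient(..., output_gradients=...)` calls are vector-Jacobian products, so for upstream gradients $dy_1, dy_2$ the code computes

$$
\begin{aligned}
dz_1 &= dy_1 + \big(\partial g/\partial z_1\big)^{\top} dy_2 && \text{(line 5)}\\
dx_2 &= dy_2 + \big(\partial f/\partial x_2\big)^{\top} dz_1 && \text{(line 6)}\\
dx_1 &= dz_1 && \text{(line 7)}\\
dw_f &= \big(\partial f/\partial w_f\big)^{\top} dz_1 && \text{(line 8)}\\
dw_g &= \big(\partial g/\partial w_g\big)^{\top} dy_2 && \text{(line 9)}
\end{aligned}
$$

which lets the layer recompute activations instead of storing them, the memory saving that motivates reversible blocks.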
@@ -1263,10 +1261,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
         'filters': self._filters,
         'strides': self._strides,
         'regularize_depthwise': self._regularize_depthwise,
-        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
-        'bias_regularizer': self._bias_regularizer,
         'activation': self._activation,
         'use_sync_bn': self._use_sync_bn,
         'norm_momentum': self._norm_momentum,
...
@@ -32,8 +32,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
           strategy_combinations.default_strategy,
           strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
-      ],
-  )
+      ],)
 class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
@@ -92,9 +91,9 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
       (nn_blocks.InvertedBottleneckBlock, 1, 1, 0.2, None),
       (nn_blocks.InvertedBottleneckBlock, 1, 1, None, 0.2),
   )
-  def test_invertedbottleneck_block_creation(
-      self, block_fn, expand_ratio, strides, se_ratio,
-      stochastic_depth_drop_rate):
+  def test_invertedbottleneck_block_creation(self, block_fn, expand_ratio,
+                                             strides, se_ratio,
+                                             stochastic_depth_drop_rate):
     input_size = 128
     in_filters = 24
     out_filters = 40
@@ -149,6 +148,32 @@ class BottleneckResidualInnerTest(parameterized.TestCase, tf.test.TestCase):
     self.assertEqual(expected_output_shape, output.shape.as_list())
 
 
+class DepthwiseSeparableConvBlockTest(parameterized.TestCase, tf.test.TestCase):
+
+  @combinations.generate(distribution_strategy_combinations())
+  def test_shape(self, distribution):
+    batch_size, height, width, num_channels = 8, 32, 32, 32
+    num_filters = 64
+    strides = 2
+
+    input_tensor = tf.random.normal(
+        shape=[batch_size, height, width, num_channels])
+    with distribution.scope():
+      block = nn_blocks.DepthwiseSeparableConvBlock(
+          num_filters, strides=strides)
+      config_dict = block.get_config()
+      recreate_block = nn_blocks.DepthwiseSeparableConvBlock(**config_dict)
+
+    output_tensor = block(input_tensor)
+    expected_output_shape = [
+        batch_size, height // strides, width // strides, num_filters
+    ]
+    self.assertEqual(output_tensor.shape.as_list(), expected_output_shape)
+
+    output_tensor = recreate_block(input_tensor)
+    self.assertEqual(output_tensor.shape.as_list(), expected_output_shape)
+
+
 class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
 
   @combinations.generate(distribution_strategy_combinations())
@@ -160,13 +185,9 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
     with distribution.scope():
       f = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=True)
+          filters=filters // 2, strides=strides, batch_norm_first=True)
       g = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=True)
+          filters=filters // 2, strides=1, batch_norm_first=True)
       test_layer = nn_blocks.ReversibleLayer(f, g)
       test_layer.build(input_tensor.shape)
       optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
@@ -199,13 +220,9 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
     with distribution.scope():
       f = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       test_layer = nn_blocks.ReversibleLayer(f, g)
       test_layer(input_tensor, training=False)  # init weights
       optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
@@ -247,24 +264,16 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c * 4])  # bottleneck
     with distribution.scope():
       f_manual = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g_manual = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       manual_grad_layer = nn_blocks.ReversibleLayer(f_manual, g_manual)
       manual_grad_layer(input_tensor, training=False)  # init weights
 
       f_auto = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g_auto = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       auto_grad_layer = nn_blocks.ReversibleLayer(
           f_auto, g_auto, manual_grads=False)
       auto_grad_layer(input_tensor)  # init weights
@@ -294,12 +303,12 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
       self.assertAllClose(
           distribution.experimental_local_results(manual_grad),
           distribution.experimental_local_results(auto_grad),
-          atol=5e-3, rtol=5e-3)
+          atol=5e-3,
+          rtol=5e-3)
 
     # Verify that BN moving mean and variance is correct.
-    for manual_var, auto_var in zip(
-        manual_grad_layer.non_trainable_variables,
-        auto_grad_layer.non_trainable_variables):
+    for manual_var, auto_var in zip(manual_grad_layer.non_trainable_variables,
+                                    auto_grad_layer.non_trainable_variables):
       self.assertAllClose(manual_var, auto_var)
...