Commit cb36903f authored by Hao Wu, committed by A. Unique TensorFlower

Internal changes.

PiperOrigin-RevId: 409201490
parent e888406e
@@ -32,9 +32,7 @@ def _pad_strides(strides: int, axis: int) -> Tuple[int, int, int, int]:
   return (1, strides, strides, 1)
 
 
-def _maybe_downsample(x: tf.Tensor,
-                      out_filter: int,
-                      strides: int,
+def _maybe_downsample(x: tf.Tensor, out_filter: int, strides: int,
                       axis: int) -> tf.Tensor:
   """Downsamples feature map and 0-pads tensor if in_filter != out_filter."""
   data_format = 'NCHW' if axis == 1 else 'NHWC'
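For context, _maybe_downsample (per its docstring) aligns a tensor's spatial and channel shape with its residual branch. Below is a minimal sketch of that idea, assuming average pooling for the spatial reduction and symmetric zero padding of the channel axis in NHWC layout; the name downsample_and_pad is hypothetical and this is not the exact nn_blocks.py implementation.

import tensorflow as tf

def downsample_and_pad(x: tf.Tensor, out_filter: int, strides: int) -> tf.Tensor:
  """Illustrative only: strided average pooling, then zero-pad channels (NHWC)."""
  if strides > 1:
    # Reduce spatial resolution so the tensor matches the strided branch.
    x = tf.nn.avg_pool2d(x, ksize=strides, strides=strides, padding='VALID')
  in_filter = x.shape[-1]
  if in_filter < out_filter:
    # Zero-pad the channel axis, half before and half after.
    pad = out_filter - in_filter
    x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [pad // 2, pad - pad // 2]])
  return x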
@@ -738,8 +736,7 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
     x = self._conv2(x)
     x = self._norm2(x)
 
-    if (self._use_residual and
-        self._in_filters == self._out_filters and
+    if (self._use_residual and self._in_filters == self._out_filters and
         self._strides == 1):
       if self._stochastic_depth:
         x = self._stochastic_depth(x, training=training)
......@@ -859,8 +856,9 @@ class ResidualInner(tf.keras.layers.Layer):
base_config = super(ResidualInner, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(
self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor:
x = inputs
if self._batch_norm_first:
x = self._batch_norm_0(x, training=training)
@@ -993,8 +991,9 @@ class BottleneckResidualInner(tf.keras.layers.Layer):
     base_config = super(BottleneckResidualInner, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def call(
-      self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:
     x = inputs
     if self._batch_norm_first:
       x = self._batch_norm_0(x, training=training)
@@ -1063,20 +1062,23 @@ class ReversibleLayer(tf.keras.layers.Layer):
 
   def _ckpt_non_trainable_vars(self):
     self._f_non_trainable_vars = [
-        v.read_value() for v in self._f.non_trainable_variables]
+        v.read_value() for v in self._f.non_trainable_variables
+    ]
     self._g_non_trainable_vars = [
-        v.read_value() for v in self._g.non_trainable_variables]
+        v.read_value() for v in self._g.non_trainable_variables
+    ]
 
   def _load_ckpt_non_trainable_vars(self):
-    for v, v_chkpt in zip(
-        self._f.non_trainable_variables, self._f_non_trainable_vars):
+    for v, v_chkpt in zip(self._f.non_trainable_variables,
                           self._f_non_trainable_vars):
       v.assign(v_chkpt)
-    for v, v_chkpt in zip(
-        self._g.non_trainable_variables, self._g_non_trainable_vars):
+    for v, v_chkpt in zip(self._g.non_trainable_variables,
                           self._g_non_trainable_vars):
       v.assign(v_chkpt)
 
-  def call(
-      self, inputs: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+  def call(self,
+           inputs: tf.Tensor,
+           training: Optional[bool] = None) -> tf.Tensor:
 
     @tf.custom_gradient
     def reversible(
@@ -1101,12 +1103,12 @@ class ReversibleLayer(tf.keras.layers.Layer):
         fwdtape.watch(x)
         x1, x2 = tf.split(x, num_or_size_splits=2, axis=self._axis)
         f_x2 = self._f(x2, training=training)
-        x1_down = _maybe_downsample(
-            x1, f_x2.shape[self._axis], self._f.strides, self._axis)
+        x1_down = _maybe_downsample(x1, f_x2.shape[self._axis], self._f.strides,
+                                    self._axis)
         z1 = f_x2 + x1_down
         g_z1 = self._g(z1, training=training)
-        x2_down = _maybe_downsample(
-            x2, g_z1.shape[self._axis], self._f.strides, self._axis)
+        x2_down = _maybe_downsample(x2, g_z1.shape[self._axis], self._f.strides,
+                                    self._axis)
         y2 = x2_down + g_z1
 
         # Equation 8: https://arxiv.org/pdf/1707.04585.pdf
@@ -1114,17 +1116,17 @@ class ReversibleLayer(tf.keras.layers.Layer):
         y1 = tf.identity(z1)
 
         y = tf.concat([y1, y2], axis=self._axis)
-        irreversible = (
-            (self._f.strides != 1 or self._g.strides != 1)
-            or (y.shape[self._axis] != inputs.shape[self._axis]))
+        irreversible = ((self._f.strides != 1 or self._g.strides != 1) or
+                        (y.shape[self._axis] != inputs.shape[self._axis]))
 
         # Checkpointing moving mean/variance for batch normalization layers
         # as they shouldn't be updated during the custom gradient pass of f/g.
         self._ckpt_non_trainable_vars()
 
-      def grad_fn(dy: tf.Tensor,
-                  variables: Optional[List[tf.Variable]] = None,
-                  ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
+      def grad_fn(
+          dy: tf.Tensor,
+          variables: Optional[List[tf.Variable]] = None,
+      ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
         """Given dy calculate (dy/dx)|_{x_{input}} using f/g."""
         if irreversible or not self._manual_grads:
           grads_combined = fwdtape.gradient(
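For context, grad_fn above follows the tf.custom_gradient contract: when the decorated function reads variables, its gradient function must accept a variables keyword argument and return both the gradients with respect to the inputs and a list of gradients for those variables. A minimal self-contained sketch of that contract (the scale function and the values here are illustrative, not taken from ReversibleLayer):

import tensorflow as tf

v = tf.Variable(2.0)

@tf.custom_gradient
def scale(x):
  y = x * v

  def grad_fn(dy, variables=None):
    # Gradient w.r.t. the input, plus one gradient per watched variable.
    dx = dy * v
    dvars = [dy * x] if variables else []
    return dx, dvars

  return y, grad_fn

with tf.GradientTape() as tape:
  x = tf.constant(3.0)
  tape.watch(x)
  y = scale(x)
print(tape.gradient(y, [x, v]))  # dy/dx = 2.0, dy/dv = 3.0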
@@ -1158,16 +1160,12 @@ class ReversibleLayer(tf.keras.layers.Layer):
 
           # Compute gradients
           g_grads_combined = gtape.gradient(
-              g_z1,
-              [z1] + self._g.trainable_variables,
-              output_gradients=dy2)
+              g_z1, [z1] + self._g.trainable_variables, output_gradients=dy2)
           dz1 = dy1 + g_grads_combined[0]  # line 5
           dwg = g_grads_combined[1:]  # line 9
 
           f_grads_combined = ftape.gradient(
-              f_x2,
-              [x2] + self._f.trainable_variables,
-              output_gradients=dz1)
+              f_x2, [x2] + self._f.trainable_variables, output_gradients=dz1)
           dx2 = dy2 + f_grads_combined[0]  # line 6
           dwf = f_grads_combined[1:]  # line 8
           dx1 = dz1  # line 7
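The "# line N" comments above refer to Algorithm 1 of the RevNet paper linked in the code (https://arxiv.org/pdf/1707.04585.pdf). The property it relies on: with y1 = x1 + f(x2) and y2 = x2 + g(y1), the inputs can be recomputed exactly from the outputs, so activations do not need to be stored for the backward pass. A minimal sketch of that algebra (the helper names and the toy f/g below are illustrative, not part of nn_blocks.py):

import tensorflow as tf

def reversible_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def reversible_inverse(y1, y2, f, g):
  # Undo the two residual additions in reverse order.
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

f = lambda t: 0.5 * tf.nn.relu(t)  # stand-ins for the f/g sub-layers
g = lambda t: tf.tanh(t)
x1 = tf.random.uniform([2, 8])
x2 = tf.random.uniform([2, 8])
y1, y2 = reversible_forward(x1, x2, f, g)
rx1, rx2 = reversible_inverse(y1, y2, f, g)
tf.debugging.assert_near(rx1, x1)
tf.debugging.assert_near(rx2, x2)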
@@ -1263,10 +1261,8 @@ class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
         'filters': self._filters,
         'strides': self._strides,
         'regularize_depthwise': self._regularize_depthwise,
-        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
-        'bias_regularizer': self._bias_regularizer,
         'activation': self._activation,
         'use_sync_bn': self._use_sync_bn,
         'norm_momentum': self._norm_momentum,
@@ -32,8 +32,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
           strategy_combinations.default_strategy,
           strategy_combinations.cloud_tpu_strategy,
           strategy_combinations.one_device_strategy_gpu,
-      ],
-  )
+      ],)
 
 
 class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
......@@ -92,9 +91,9 @@ class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):
(nn_blocks.InvertedBottleneckBlock, 1, 1, 0.2, None),
(nn_blocks.InvertedBottleneckBlock, 1, 1, None, 0.2),
)
def test_invertedbottleneck_block_creation(
self, block_fn, expand_ratio, strides, se_ratio,
stochastic_depth_drop_rate):
def test_invertedbottleneck_block_creation(self, block_fn, expand_ratio,
strides, se_ratio,
stochastic_depth_drop_rate):
input_size = 128
in_filters = 24
out_filters = 40
@@ -149,6 +148,32 @@ class BottleneckResidualInnerTest(parameterized.TestCase, tf.test.TestCase):
     self.assertEqual(expected_output_shape, output.shape.as_list())
 
 
+class DepthwiseSeparableConvBlockTest(parameterized.TestCase, tf.test.TestCase):
+
+  @combinations.generate(distribution_strategy_combinations())
+  def test_shape(self, distribution):
+    batch_size, height, width, num_channels = 8, 32, 32, 32
+    num_filters = 64
+    strides = 2
+
+    input_tensor = tf.random.normal(
+        shape=[batch_size, height, width, num_channels])
+    with distribution.scope():
+      block = nn_blocks.DepthwiseSeparableConvBlock(
+          num_filters, strides=strides)
+      config_dict = block.get_config()
+      recreate_block = nn_blocks.DepthwiseSeparableConvBlock(**config_dict)
+
+    output_tensor = block(input_tensor)
+    expected_output_shape = [
+        batch_size, height // strides, width // strides, num_filters
+    ]
+    self.assertEqual(output_tensor.shape.as_list(), expected_output_shape)
+
+    output_tensor = recreate_block(input_tensor)
+    self.assertEqual(output_tensor.shape.as_list(), expected_output_shape)
+
+
 class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
 
   @combinations.generate(distribution_strategy_combinations())
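The new test above exercises the standard Keras config round trip: get_config() should return exactly the arguments the constructor accepts, so that DepthwiseSeparableConvBlock(**block.get_config()) rebuilds an equivalent block; that is presumably why the extra entries are dropped from its get_config in the first file of this commit. A generic sketch of the pattern, using a hypothetical ScaleLayer that is not part of nn_blocks.py:

import tensorflow as tf

class ScaleLayer(tf.keras.layers.Layer):
  """Hypothetical layer used only to illustrate the config round trip."""

  def __init__(self, scale: float = 2.0, **kwargs):
    super().__init__(**kwargs)
    self._scale = scale

  def call(self, inputs):
    return inputs * self._scale

  def get_config(self):
    # Only keys the constructor accepts, merged with the base Layer config.
    config = {'scale': self._scale}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

layer = ScaleLayer(scale=3.0)
rebuilt = ScaleLayer(**layer.get_config())
assert rebuilt._scale == layer._scale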
@@ -160,13 +185,9 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
     with distribution.scope():
       f = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=True)
+          filters=filters // 2, strides=strides, batch_norm_first=True)
       g = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=True)
+          filters=filters // 2, strides=1, batch_norm_first=True)
       test_layer = nn_blocks.ReversibleLayer(f, g)
       test_layer.build(input_tensor.shape)
       optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
@@ -199,13 +220,9 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c])
     with distribution.scope():
       f = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g = nn_blocks.ResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       test_layer = nn_blocks.ReversibleLayer(f, g)
       test_layer(input_tensor, training=False)  # init weights
       optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
@@ -247,24 +264,16 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
     input_tensor = tf.random.uniform(shape=[bsz, h, w, c * 4])  # bottleneck
     with distribution.scope():
       f_manual = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g_manual = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       manual_grad_layer = nn_blocks.ReversibleLayer(f_manual, g_manual)
       manual_grad_layer(input_tensor, training=False)  # init weights
 
       f_auto = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=strides,
-          batch_norm_first=False)
+          filters=filters // 2, strides=strides, batch_norm_first=False)
       g_auto = nn_blocks.BottleneckResidualInner(
-          filters=filters // 2,
-          strides=1,
-          batch_norm_first=False)
+          filters=filters // 2, strides=1, batch_norm_first=False)
       auto_grad_layer = nn_blocks.ReversibleLayer(
           f_auto, g_auto, manual_grads=False)
       auto_grad_layer(input_tensor)  # init weights
@@ -294,12 +303,12 @@ class ReversibleLayerTest(parameterized.TestCase, tf.test.TestCase):
       self.assertAllClose(
           distribution.experimental_local_results(manual_grad),
           distribution.experimental_local_results(auto_grad),
-          atol=5e-3, rtol=5e-3)
+          atol=5e-3,
+          rtol=5e-3)
 
     # Verify that BN moving mean and variance is correct.
-    for manual_var, auto_var in zip(
-        manual_grad_layer.non_trainable_variables,
-        auto_grad_layer.non_trainable_variables):
+    for manual_var, auto_var in zip(manual_grad_layer.non_trainable_variables,
+                                    auto_grad_layer.non_trainable_variables):
       self.assertAllClose(manual_var, auto_var)