"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "50876abc47f1be4739f837e477bcf7506fafcff3"
Commit 7954d4b1 authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Refactor resnet_3d.

PiperOrigin-RevId: 436815441
parent 53fb1c67
...@@ -153,19 +153,76 @@ class ResNet3D(tf.keras.Model): ...@@ -153,19 +153,76 @@ class ResNet3D(tf.keras.Model):
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
bn_axis = -1 self._bn_axis = -1
else: else:
bn_axis = 1 self._bn_axis = 1
# Build ResNet3D backbone. # Build ResNet3D backbone.
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
endpoints = self._build_model(inputs)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(ResNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def _build_model(self, inputs):
"""Builds model architecture.
Args:
inputs: the keras input spec.
Returns:
endpoints: A dictionary of backbone endpoint features.
"""
# Build stem.
x = self._build_stem(inputs, stem_type=self._stem_type)
temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
x = layers.MaxPool3D(
pool_size=[temporal_kernel_size, 3, 3],
strides=[self._stem_pool_temporal_stride, 2, 2],
padding='same')(x)
# Build intermediate blocks and endpoints.
resnet_specs = RESNET_SPECS[self._model_id]
if len(self._temporal_strides) != len(resnet_specs) or len(
self._temporal_kernel_sizes) != len(resnet_specs):
raise ValueError(
'Number of blocks in temporal specs should equal to resnet_specs.')
endpoints = {}
for i, resnet_spec in enumerate(resnet_specs):
if resnet_spec[0] == 'bottleneck3d':
block_fn = nn_blocks_3d.BottleneckBlock3D
else:
raise ValueError('Block fn `{}` is not supported.'.format(
resnet_spec[0]))
use_self_gating = (
self._use_self_gating[i] if self._use_self_gating else False)
x = self._block_group(
inputs=x,
filters=resnet_spec[1],
temporal_kernel_sizes=self._temporal_kernel_sizes[i],
temporal_strides=self._temporal_strides[i],
spatial_strides=(1 if i == 0 else 2),
block_fn=block_fn,
block_repeats=resnet_spec[2],
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 2, 5),
use_self_gating=use_self_gating,
name='block_group_l{}'.format(i + 2))
endpoints[str(i + 2)] = x
return endpoints
def _build_stem(self, inputs, stem_type):
"""Builds stem layer."""
# Build stem. # Build stem.
if stem_type == 'v0': if stem_type == 'v0':
x = layers.Conv3D( x = layers.Conv3D(
filters=64, filters=64,
kernel_size=[stem_conv_temporal_kernel_size, 7, 7], kernel_size=[self._stem_conv_temporal_kernel_size, 7, 7],
strides=[stem_conv_temporal_stride, 2, 2], strides=[self._stem_conv_temporal_stride, 2, 2],
use_bias=False, use_bias=False,
padding='same', padding='same',
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -173,14 +230,15 @@ class ResNet3D(tf.keras.Model): ...@@ -173,14 +230,15 @@ class ResNet3D(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=self._bn_axis,
x) momentum=self._norm_momentum,
x = tf_utils.get_activation(activation)(x) epsilon=self._norm_epsilon)(x)
x = tf_utils.get_activation(self._activation)(x)
elif stem_type == 'v1': elif stem_type == 'v1':
x = layers.Conv3D( x = layers.Conv3D(
filters=32, filters=32,
kernel_size=[stem_conv_temporal_kernel_size, 3, 3], kernel_size=[self._stem_conv_temporal_kernel_size, 3, 3],
strides=[stem_conv_temporal_stride, 2, 2], strides=[self._stem_conv_temporal_stride, 2, 2],
use_bias=False, use_bias=False,
padding='same', padding='same',
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -188,9 +246,10 @@ class ResNet3D(tf.keras.Model): ...@@ -188,9 +246,10 @@ class ResNet3D(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
inputs) inputs)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=self._bn_axis,
x) momentum=self._norm_momentum,
x = tf_utils.get_activation(activation)(x) epsilon=self._norm_epsilon)(x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D( x = layers.Conv3D(
filters=32, filters=32,
kernel_size=[1, 3, 3], kernel_size=[1, 3, 3],
...@@ -202,9 +261,10 @@ class ResNet3D(tf.keras.Model): ...@@ -202,9 +261,10 @@ class ResNet3D(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=self._bn_axis,
x) momentum=self._norm_momentum,
x = tf_utils.get_activation(activation)(x) epsilon=self._norm_epsilon)(x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.Conv3D( x = layers.Conv3D(
filters=64, filters=64,
kernel_size=[1, 3, 3], kernel_size=[1, 3, 3],
...@@ -216,51 +276,14 @@ class ResNet3D(tf.keras.Model): ...@@ -216,51 +276,14 @@ class ResNet3D(tf.keras.Model):
bias_regularizer=self._bias_regularizer)( bias_regularizer=self._bias_regularizer)(
x) x)
x = self._norm( x = self._norm(
axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)( axis=self._bn_axis,
x) momentum=self._norm_momentum,
x = tf_utils.get_activation(activation)(x) epsilon=self._norm_epsilon)(x)
x = tf_utils.get_activation(self._activation)(x)
else: else:
raise ValueError(f'Stem type {stem_type} not supported.') raise ValueError(f'Stem type {stem_type} not supported.')
temporal_kernel_size = 1 if stem_pool_temporal_stride == 1 else 3 return x
x = layers.MaxPool3D(
pool_size=[temporal_kernel_size, 3, 3],
strides=[stem_pool_temporal_stride, 2, 2],
padding='same')(
x)
# Build intermediate blocks and endpoints.
resnet_specs = RESNET_SPECS[model_id]
if len(temporal_strides) != len(resnet_specs) or len(
temporal_kernel_sizes) != len(resnet_specs):
raise ValueError(
'Number of blocks in temporal specs should equal to resnet_specs.')
endpoints = {}
for i, resnet_spec in enumerate(resnet_specs):
if resnet_spec[0] == 'bottleneck3d':
block_fn = nn_blocks_3d.BottleneckBlock3D
else:
raise ValueError('Block fn `{}` is not supported.'.format(
resnet_spec[0]))
x = self._block_group(
inputs=x,
filters=resnet_spec[1],
temporal_kernel_sizes=temporal_kernel_sizes[i],
temporal_strides=temporal_strides[i],
spatial_strides=(1 if i == 0 else 2),
block_fn=block_fn,
block_repeats=resnet_spec[2],
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 2, 5),
use_self_gating=use_self_gating[i] if use_self_gating else False,
name='block_group_l{}'.format(i + 2))
endpoints[str(i + 2)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(ResNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)
def _block_group(self, def _block_group(self,
inputs: tf.Tensor, inputs: tf.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment