Unverified Commit ca552843 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-segmentation

parents 7e2f7a35 6b90e134
......@@ -338,7 +338,7 @@ with the Python API:
```python
# Create the interpreter and signature runner
interpreter = tf.lite.Interpreter('/tmp/movinet_a0_stream.tflite')
signature = interpreter.get_signature_runner()
runner = interpreter.get_signature_runner()
# Extract state names and create the initial (zero) states
def state_name(name: str) -> str:
......@@ -358,7 +358,7 @@ clips = tf.split(video, video.shape[1], axis=1)
states = init_states
for clip in clips:
# Input shape: [1, 1, 172, 172, 3]
outputs = signature(**states, image=clip)
outputs = runner(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
```
......
......@@ -121,7 +121,7 @@ class ExportSavedModelTest(tf.test.TestCase):
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signature = interpreter.get_signature_runner()
runner = interpreter.get_signature_runner('serving_default')
def state_name(name: str) -> str:
return name[len('serving_default_'):-len(':0')]
......@@ -137,7 +137,7 @@ class ExportSavedModelTest(tf.test.TestCase):
states = init_states
for clip in clips:
outputs = signature(**states, image=clip)
outputs = runner(**states, image=clip)
logits = outputs.pop('logits')
states = outputs
......
......@@ -17,10 +17,10 @@
Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
import dataclasses
import math
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
import dataclasses
import tensorflow as tf
from official.modeling import hyperparams
......@@ -454,7 +454,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_idx = 1
for block_idx, block in enumerate(self._block_specs):
if isinstance(block, StemSpec):
x, states = movinet_layers.Stem(
layer_obj = movinet_layers.Stem(
block.filters,
block.kernel_size,
block.strides,
......@@ -466,9 +466,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/stem',
name='stem')(
x, states=states)
state_prefix='state_stem',
name='stem')
x, states = layer_obj(x, states=states)
endpoints['stem'] = x
elif isinstance(block, MovinetBlockSpec):
if not (len(block.expand_filters) == len(block.kernel_sizes) ==
......@@ -486,8 +486,8 @@ class Movinet(tf.keras.Model):
self._stochastic_depth_drop_rate * stochastic_depth_idx /
num_layers)
expand_filters, kernel_size, strides = layer
name = f'b{block_idx-1}/l{layer_idx}'
x, states = movinet_layers.MovinetBlock(
name = f'block{block_idx-1}_layer{layer_idx}'
layer_obj = movinet_layers.MovinetBlock(
block.base_filters,
expand_filters,
kernel_size=kernel_size,
......@@ -505,13 +505,14 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix=f'state/{name}',
name=name)(
x, states=states)
state_prefix=f'state_{name}',
name=name)
x, states = layer_obj(x, states=states)
endpoints[name] = x
stochastic_depth_idx += 1
elif isinstance(block, HeadSpec):
x, states = movinet_layers.Head(
layer_obj = movinet_layers.Head(
project_filters=block.project_filters,
conv_type=self._conv_type,
activation=self._activation,
......@@ -520,9 +521,9 @@ class Movinet(tf.keras.Model):
batch_norm_layer=self._norm,
batch_norm_momentum=self._norm_momentum,
batch_norm_epsilon=self._norm_epsilon,
state_prefix='state/head',
name='head')(
x, states=states)
state_prefix='state_head',
name='head')
x, states = layer_obj(x, states=states)
endpoints['head'] = x
else:
raise ValueError('Unknown block type {}'.format(block))
......@@ -567,7 +568,7 @@ class Movinet(tf.keras.Model):
for block_idx, block in enumerate(block_specs):
if isinstance(block, StemSpec):
if block.kernel_size[0] > 1:
states['state/stem/stream_buffer'] = (
states['state_stem_stream_buffer'] = (
input_shape[0],
input_shape[1],
divide_resolution(input_shape[2], num_downsamples),
......@@ -590,8 +591,10 @@ class Movinet(tf.keras.Model):
self._conv_type in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1
prefix = f'state_block{block_idx}_layer{layer_idx}'
if kernel_size[0] > 1:
states[f'state/b{block_idx}/l{layer_idx}/stream_buffer'] = (
states[f'{prefix}_stream_buffer'] = (
input_shape[0],
kernel_size[0] - 1,
divide_resolution(input_shape[2], num_downsamples),
......@@ -599,13 +602,13 @@ class Movinet(tf.keras.Model):
expand_filters,
)
states[f'state/b{block_idx}/l{layer_idx}/pool_buffer'] = (
states[f'{prefix}_pool_buffer'] = (
input_shape[0], 1, 1, 1, expand_filters,
)
states[f'state/b{block_idx}/l{layer_idx}/pool_frame_count'] = (1,)
states[f'{prefix}_pool_frame_count'] = (1,)
if use_positional_encoding:
name = f'state/b{block_idx}/l{layer_idx}/pos_enc_frame_count'
name = f'{prefix}_pos_enc_frame_count'
states[name] = (1,)
if strides[1] != strides[2]:
......@@ -618,10 +621,10 @@ class Movinet(tf.keras.Model):
self._conv_type not in ['2plus1d', '3d_2plus1d']):
num_downsamples += 1
elif isinstance(block, HeadSpec):
states['state/head/pool_buffer'] = (
states['state_head_pool_buffer'] = (
input_shape[0], 1, 1, 1, block.project_filters,
)
states['state/head/pool_frame_count'] = (1,)
states['state_head_pool_frame_count'] = (1,)
return states
......
......@@ -478,7 +478,7 @@ class StreamBuffer(tf.keras.layers.Layer):
state_prefix = state_prefix if state_prefix is not None else ''
self._state_prefix = state_prefix
self._state_name = f'{state_prefix}/stream_buffer'
self._state_name = f'{state_prefix}_stream_buffer'
self._buffer_size = buffer_size
def get_config(self):
......@@ -501,7 +501,7 @@ class StreamBuffer(tf.keras.layers.Layer):
inputs: the input tensor.
states: a dict of states such that, if any of the keys match for this
layer, will overwrite the contents of the buffer(s).
Expected keys include `state_prefix + '/stream_buffer'`.
Expected keys include `state_prefix + '_stream_buffer'`.
Returns:
the output tensor and states
......
......@@ -35,11 +35,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, states = network(inputs)
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(states)
......@@ -59,11 +59,11 @@ class MoViNetTest(parameterized.TestCase, tf.test.TestCase):
endpoints, new_states = backbone({**init_states, 'image': inputs})
self.assertAllEqual(endpoints['stem'].shape, [1, 8, 64, 64, 8])
self.assertAllEqual(endpoints['b0/l0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['b1/l0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['b2/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b3/l0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['b4/l0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['block0_layer0'].shape, [1, 8, 32, 32, 8])
self.assertAllEqual(endpoints['block1_layer0'].shape, [1, 8, 16, 16, 32])
self.assertAllEqual(endpoints['block2_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block3_layer0'].shape, [1, 8, 8, 8, 56])
self.assertAllEqual(endpoints['block4_layer0'].shape, [1, 8, 4, 4, 104])
self.assertAllEqual(endpoints['head'].shape, [1, 1, 1, 1, 480])
self.assertNotEmpty(init_states)
......
......@@ -22,6 +22,7 @@ from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation
......@@ -47,14 +48,31 @@ class Parser(maskrcnn.Parser):
segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
default_factory=list)
segmentation_ignore_label: int = 255
panoptic_ignore_label: int = 0
# Setting this to true will enable parsing category_mask and instance_mask.
include_panoptic_masks: bool = True
@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
"""TF Example decoder config with a switch for panoptic mask decoding."""
# When True, the decoder built from this config also decodes the panoptic
# category_mask and instance_mask features.
include_panoptic_masks: bool = True
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
  """Data decoder config.

  Declares `simple_decoder` with the panoptic-aware `TfExampleDecoder` config
  defined in this module.
  """
  # Built through `default_factory` so that every config instance gets its own
  # sub-config. Using a dataclass instance directly as the default shares one
  # object across instances, and Python 3.11+ rejects unhashable defaults.
  simple_decoder: TfExampleDecoder = dataclasses.field(
      default_factory=TfExampleDecoder)
@dataclasses.dataclass
class DataConfig(maskrcnn.DataConfig):
  """Input config for training."""
  # Sub-configs are created per instance via `default_factory`. A dataclass
  # instance used directly as a default is shared between instances and is
  # rejected as an unhashable default by Python 3.11+.
  decoder: DataDecoder = dataclasses.field(default_factory=DataDecoder)
  parser: Parser = dataclasses.field(default_factory=Parser)
# @dataclasses.dataclass
@dataclasses.dataclass
class PanopticSegmentationGenerator(hyperparams.Config):
output_size: List[int] = dataclasses.field(
......
......@@ -24,25 +24,51 @@ from official.vision.beta.ops import preprocess_ops
class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id, mask_binarize_threshold):
def __init__(self, regenerate_source_id,
mask_binarize_threshold, include_panoptic_masks):
super(TfExampleDecoder, self).__init__(
include_mask=True,
regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=None)
self._segmentation_keys_to_features = {
self._include_panoptic_masks = include_panoptic_masks
keys_to_features = {
'image/segmentation/class/encoded':
tf.io.FixedLenFeature((), tf.string, default_value='')
}
tf.io.FixedLenFeature((), tf.string, default_value='')}
if include_panoptic_masks:
keys_to_features.update({
'image/panoptic/category_mask':
tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/panoptic/instance_mask':
tf.io.FixedLenFeature((), tf.string, default_value='')})
self._segmentation_keys_to_features = keys_to_features
def decode(self, serialized_example):
decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example)
segmentation_parsed_tensors = tf.io.parse_single_example(
parsed_tensors = tf.io.parse_single_example(
serialized_example, self._segmentation_keys_to_features)
segmentation_mask = tf.io.decode_image(
segmentation_parsed_tensors['image/segmentation/class/encoded'],
parsed_tensors['image/segmentation/class/encoded'],
channels=1)
segmentation_mask.set_shape([None, None, 1])
decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
if self._include_panoptic_masks:
category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'],
channels=1)
instance_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/instance_mask'],
channels=1)
category_mask.set_shape([None, None, 1])
instance_mask.set_shape([None, None, 1])
decoded_tensors.update({
'groundtruth_panoptic_category_mask':
category_mask,
'groundtruth_panoptic_instance_mask':
instance_mask})
return decoded_tensors
......@@ -69,6 +95,8 @@ class Parser(maskrcnn_input.Parser):
segmentation_resize_eval_groundtruth=True,
segmentation_groundtruth_padded_size=None,
segmentation_ignore_label=255,
panoptic_ignore_label=0,
include_panoptic_masks=True,
dtype='float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -106,8 +134,12 @@ class Parser(maskrcnn_input.Parser):
segmentation_groundtruth_padded_size: `Tensor` or `list` for [height,
width]. When resize_eval_groundtruth is set to False, the groundtruth
masks are padded to this size.
segmentation_ignore_label: `int` the pixel with ignore label will not used
for training and evaluation.
segmentation_ignore_label: `int` the pixels with ignore label will not be
used for training and evaluation.
panoptic_ignore_label: `int` the pixels with ignore label will not be used
by the PQ evaluator.
include_panoptic_masks: `bool`, if True, category_mask and instance_mask
will be parsed. Set this to true if PQ evaluator is enabled.
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
"""
super(Parser, self).__init__(
......@@ -139,6 +171,8 @@ class Parser(maskrcnn_input.Parser):
'specified when segmentation_resize_eval_groundtruth is False.')
self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size
self._segmentation_ignore_label = segmentation_ignore_label
self._panoptic_ignore_label = panoptic_ignore_label
self._include_panoptic_masks = include_panoptic_masks
def _parse_train_data(self, data):
"""Parses data for training.
......@@ -250,39 +284,54 @@ class Parser(maskrcnn_input.Parser):
shape [height_l, width_l, 4] representing anchor boxes at each
level.
"""
segmentation_mask = tf.cast(
data['groundtruth_segmentation_mask'], tf.float32)
segmentation_mask = tf.reshape(
segmentation_mask, shape=[1, data['height'], data['width'], 1])
segmentation_mask += 1
def _process_mask(mask, ignore_label, image_info):
  """Resizes or pads one groundtruth mask and marks padding as ignored.

  Args:
    mask: groundtruth mask tensor; it is cast to float32 and reshaped to
      [1, data['height'], data['width'], 1].
    ignore_label: value assigned to pixels that end up in the padded region.
    image_info: tensor whose rows 2 and 3 are used as the image scale and
      the offset when the mask is resized and cropped.

  Returns:
    The processed float32 mask with the leading axis squeezed out.
  """
  batched = tf.reshape(
      tf.cast(mask, dtype=tf.float32),
      shape=[1, data['height'], data['width'], 1])
  # Shift labels up by one so that pixels left at 0 by the resize/pad step
  # become -1 once the shift is undone, and can be told apart from label 0.
  shifted = batched + 1

  if not self._segmentation_resize_eval_groundtruth:
    padded_size = self._segmentation_groundtruth_padded_size
    fitted = tf.image.pad_to_bounding_box(
        shifted, 0, 0, padded_size[0], padded_size[1])
  else:
    # Resizes eval masks to match input image sizes. In that case, mean IoU
    # is computed on output_size not the original size of the images.
    fitted = preprocess_ops.resize_and_crop_masks(
        shifted, image_info[2, :], self._output_size, image_info[3, :])
  restored = fitted - 1

  # Assign ignore label to the padded region.
  padding_fill = ignore_label * tf.ones_like(restored)
  restored = tf.where(tf.equal(restored, -1), padding_fill, restored)
  return tf.squeeze(restored, axis=0)
image, labels = super(Parser, self)._parse_eval_data(data)
image_info = labels['image_info']
if self._segmentation_resize_eval_groundtruth:
# Resizes eval masks to match input image sizes. In that case, mean IoU
# is computed on output_size not the original size of the images.
image_info = labels['image_info']
image_scale = image_info[2, :]
offset = image_info[3, :]
segmentation_mask = preprocess_ops.resize_and_crop_masks(
segmentation_mask, image_scale, self._output_size, offset)
else:
segmentation_mask = tf.image.pad_to_bounding_box(
segmentation_mask, 0, 0,
self._segmentation_groundtruth_padded_size[0],
self._segmentation_groundtruth_padded_size[1])
segmentation_mask -= 1
# Assign ignore label to the padded region.
segmentation_mask = tf.where(
tf.equal(segmentation_mask, -1),
self._segmentation_ignore_label * tf.ones_like(segmentation_mask),
segmentation_mask)
segmentation_mask = tf.squeeze(segmentation_mask, axis=0)
segmentation_mask = _process_mask(
data['groundtruth_segmentation_mask'],
self._segmentation_ignore_label, image_info)
segmentation_valid_mask = tf.not_equal(
segmentation_mask, self._segmentation_ignore_label)
labels['groundtruths'].update({
'gt_segmentation_mask': segmentation_mask,
'gt_segmentation_valid_mask': segmentation_valid_mask})
if self._include_panoptic_masks:
panoptic_category_mask = _process_mask(
data['groundtruth_panoptic_category_mask'],
self._panoptic_ignore_label, image_info)
panoptic_instance_mask = _process_mask(
data['groundtruth_panoptic_instance_mask'],
self._panoptic_ignore_label, image_info)
labels['groundtruths'].update({
'gt_panoptic_category_mask': panoptic_category_mask,
'gt_panoptic_instance_mask': panoptic_instance_mask})
return image, labels
......@@ -493,7 +493,7 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
ckpt.save(os.path.join(save_dir, 'ckpt'))
partial_ckpt = tf.train.Checkpoint(backbone=backbone)
partial_ckpt.restore(tf.train.latest_checkpoint(
partial_ckpt.read(tf.train.latest_checkpoint(
save_dir)).expect_partial().assert_existing_objects_matched()
partial_ckpt_mask = tf.train.Checkpoint(
......
......@@ -78,14 +78,14 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(checkpoint_path)
status.assert_consumed()
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'backbone':
checkpoint_path = _get_checkpoint_path(
self.task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_backbone':
......@@ -93,7 +93,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint(
segmentation_backbone=model.segmentation_backbone)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
elif init_module == 'segmentation_decoder':
......@@ -101,7 +101,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
self.task_config.segmentation_init_checkpoint)
ckpt = tf.train.Checkpoint(
segmentation_decoder=model.segmentation_decoder)
status = ckpt.restore(checkpoint_path)
status = ckpt.read(checkpoint_path)
status.expect_partial().assert_existing_objects_matched()
else:
......@@ -122,7 +122,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
if params.decoder.type == 'simple_decoder':
decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_panoptic_masks=decoder_cfg.include_panoptic_masks)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......@@ -148,7 +149,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.segmentation_resize_eval_groundtruth,
segmentation_groundtruth_padded_size=params.parser
.segmentation_groundtruth_padded_size,
segmentation_ignore_label=params.parser.segmentation_ignore_label)
segmentation_ignore_label=params.parser.segmentation_ignore_label,
panoptic_ignore_label=params.parser.panoptic_ignore_label,
include_panoptic_masks=params.parser.include_panoptic_masks)
reader = input_reader_factory.input_reader_generator(
params,
......
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
......
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: ''
model:
backbone:
resnet:
model_id: 50
type: resnet
projection_head:
ft_proj_idx: 1
num_proj_layers: 3
proj_output_dim: 128
backbone_trainable: true
heads: !!python/tuple
# Define heads for the PRETRAIN networks here
- task_name: pretrain_imagenet
mode: pretrain
# Define heads for the FINETUNE networks here
- task_name: finetune_imagenet_10percent
mode: finetune
supervised_head:
num_classes: 1001
zero_init: true
input_size: [224, 224, 3]
l2_weight_decay: 0.0
norm_activation:
norm_epsilon: 1.0e-05
norm_momentum: 0.9
use_sync_bn: true
task_routines: !!python/tuple
# Define TASK CONFIG for the PRETRAIN networks here
- task_name: pretrain_imagenet
task_weight: 30.0
task_config:
evaluation:
one_hot: true
top_k: 5
loss:
l2_weight_decay: 0.0
projection_norm: true
temperature: 0.1
model:
input_size: [224, 224, 3]
mode: pretrain
train_data:
input_path: /readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*
input_set_label_to_zero: true # Set labels to zeros to double-check that no labels are used during pretraining
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
parser:
aug_rand_hflip: true
mode: pretrain
decoder:
decode_label: true
validation_data:
input_path: /readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*
is_training: false
global_batch_size: 2048
dtype: 'bfloat16'
drop_remainder: false
parser:
mode: pretrain
decoder:
decode_label: true
# Define TASK CONFIG for the FINETUNE Networks here
- task_name: finetune_imagenet_10percent
task_weight: 1.0
task_config:
evaluation:
one_hot: true
top_k: 5
loss:
l2_weight_decay: 0.0
label_smoothing: 0.0
one_hot: true
model:
input_size: [224, 224, 3]
mode: finetune
supervised_head:
num_classes: 1001
zero_init: true
train_data:
tfds_name: 'imagenet2012_subset/10pct'
tfds_split: 'train'
input_path: ''
is_training: true
global_batch_size: 1024
dtype: 'bfloat16'
parser:
aug_rand_hflip: true
mode: finetune
decoder:
decode_label: true
validation_data:
tfds_name: 'imagenet2012_subset/10pct'
tfds_split: 'validation'
input_path: ''
is_training: false
global_batch_size: 2048
dtype: 'bfloat16'
drop_remainder: false
parser:
mode: finetune
decoder:
decode_label: true
trainer:
trainer_type: interleaving
task_sampler:
proportional:
alpha: 1.0
type: proportional
train_steps: 32000 # 100 epochs
validation_steps: 24 # NUM_EXAMPLES (50000) // global_batch_size
validation_interval: 625
steps_per_loop: 625 # NUM_EXAMPLES (1281167) // global_batch_size
summary_interval: 625
checkpoint_interval: 625
max_to_keep: 3
optimizer_config:
learning_rate:
cosine:
decay_steps: 32000
initial_learning_rate: 4.8
type: cosine
optimizer:
lars:
exclude_from_weight_decay: [batch_normalization, bias]
momentum: 0.9
weight_decay_rate: 1.0e-06
type: lars
warmup:
linear:
name: linear
warmup_steps: 3200
type: linear
......@@ -29,6 +29,7 @@ from official.vision.beta.projects.simclr.modeling import simclr_model
@dataclasses.dataclass
class SimCLRMTHeadConfig(hyperparams.Config):
"""Per-task specific configs."""
task_name: str = 'task_name'
# Supervised head is required for finetune, but optional for pretrain.
supervised_head: simclr_configs.SupervisedHead = simclr_configs.SupervisedHead(
num_classes=1001)
......@@ -50,6 +51,9 @@ class SimCLRMTModelConfig(hyperparams.Config):
# L2 weight decay is used in the model, not in task.
# Note that this can not be used together with lars optimizer.
l2_weight_decay: float = 0.0
init_checkpoint: str = ''
# backbone_projection or backbone
init_checkpoint_modules: str = 'backbone_projection'
@exp_factory.register_config_factory('multitask_simclr')
......@@ -57,14 +61,17 @@ def multitask_simclr() -> multitask_configs.MultiTaskExperimentConfig:
return multitask_configs.MultiTaskExperimentConfig(
task=multitask_configs.MultiTaskConfig(
model=SimCLRMTModelConfig(
heads=(SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
SimCLRMTHeadConfig(mode=simclr_model.FINETUNE))),
heads=(SimCLRMTHeadConfig(
task_name='pretrain_simclr', mode=simclr_model.PRETRAIN),
SimCLRMTHeadConfig(
task_name='finetune_simclr',
mode=simclr_model.FINETUNE))),
task_routines=(multitask_configs.TaskRoutine(
task_name=simclr_model.PRETRAIN,
task_name='pretrain_simclr',
task_config=simclr_configs.SimCLRPretrainTask(),
task_weight=2.0),
multitask_configs.TaskRoutine(
task_name=simclr_model.FINETUNE,
task_name='finetune_simclr',
task_config=simclr_configs.SimCLRFinetuneTask(),
task_weight=1.0))),
trainer=multitask_configs.MultiTaskTrainerConfig())
......@@ -12,27 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations."""
import dataclasses
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
......@@ -73,6 +57,9 @@ class DataConfig(cfg.DataConfig):
# simclr specific configs
parser: Parser = Parser()
decoder: Decoder = Decoder()
# Useful as a sanity check that no labels are used during pretraining: when
# True, labels are set to zeros (default False keeps the original labels).
input_set_label_to_zero: bool = False
@dataclasses.dataclass
......@@ -115,9 +102,7 @@ class SimCLRModel(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
projection_head: ProjectionHead = ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1)
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
......@@ -201,9 +186,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
......@@ -233,10 +216,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.000001,
'momentum':
0.9,
'weight_decay_rate':
0.000001,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
'batch_normalization', 'bias'
]
}
},
'learning_rate': {
......@@ -278,11 +264,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
supervised_head=SupervisedHead(
num_classes=1001, zero_init=True),
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
loss=ClassificationLosses(),
......@@ -311,10 +294,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.0,
'momentum':
0.9,
'weight_decay_rate':
0.0,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
'batch_normalization', 'bias'
]
}
},
'learning_rate': {
......
......@@ -12,23 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
"""Tests for SimCLR config."""
from absl.testing import parameterized
import tensorflow as tf
......
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops."""
import functools
import tensorflow as tf
......
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR.
For pre-training:
......
......@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
"""SimCLR prediction heads."""
from typing import Text, Optional
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
......
......@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions."""
import functools
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment