Commit 59e1ab8a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388514034
parent 08b68031
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration.""" """All necessary imports for registration."""
# pylint: disable=unused-import # pylint: disable=unused-import
......
...@@ -12,26 +12,11 @@ ...@@ -12,26 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations.""" """SimCLR configurations."""
import os import dataclasses
import os.path
from typing import List, Optional from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
...@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config): ...@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone( backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet()) type='resnet', resnet=backbones.ResNet())
projection_head: ProjectionHead = ProjectionHead( projection_head: ProjectionHead = ProjectionHead(
proj_output_dim=128, proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
num_proj_layers=3,
ft_proj_idx=1)
supervised_head: SupervisedHead = SupervisedHead(num_classes=1001) supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
norm_activation: common.NormActivation = common.NormActivation( norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False) norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
...@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig: ...@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)), type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead( projection_head=ProjectionHead(
proj_output_dim=128, proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
num_proj_layers=3,
ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001), supervised_head=SupervisedHead(num_classes=1001),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)), norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
...@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig: ...@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
'optimizer': { 'optimizer': {
'type': 'lars', 'type': 'lars',
'lars': { 'lars': {
'momentum': 0.9, 'momentum':
'weight_decay_rate': 0.000001, 0.9,
'weight_decay_rate':
0.000001,
'exclude_from_weight_decay': [ 'exclude_from_weight_decay': [
'batch_normalization', 'bias'] 'batch_normalization', 'bias'
]
} }
}, },
'learning_rate': { 'learning_rate': {
...@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig: ...@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)), type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead( projection_head=ProjectionHead(
proj_output_dim=128, proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
num_proj_layers=3, supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
ft_proj_idx=1),
supervised_head=SupervisedHead(
num_classes=1001, zero_init=True),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)), norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
loss=ClassificationLosses(), loss=ClassificationLosses(),
...@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig: ...@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
'optimizer': { 'optimizer': {
'type': 'lars', 'type': 'lars',
'lars': { 'lars': {
'momentum': 0.9, 'momentum':
'weight_decay_rate': 0.0, 0.9,
'weight_decay_rate':
0.0,
'exclude_from_weight_decay': [ 'exclude_from_weight_decay': [
'batch_normalization', 'bias'] 'batch_normalization', 'bias'
]
} }
}, },
'learning_rate': { 'learning_rate': {
......
...@@ -12,23 +12,7 @@ ...@@ -12,23 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3 """Tests for SimCLR config."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
......
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops.""" """Preprocessing ops."""
import functools import functools
import tensorflow as tf import tensorflow as tf
......
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR. """Data parser and processing for SimCLR.
For pre-training: For pre-training:
......
...@@ -12,21 +12,7 @@ ...@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved. """SimCLR prediction heads."""
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
from typing import Text, Optional from typing import Text, Optional
......
...@@ -12,22 +12,6 @@ ...@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
......
...@@ -12,21 +12,6 @@ ...@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions.""" """Contrastive loss functions."""
import functools import functools
......
...@@ -12,22 +12,6 @@ ...@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
......
...@@ -12,22 +12,6 @@ ...@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for simclr neural networks.""" """Contains common building blocks for simclr neural networks."""
from typing import Text, Optional from typing import Text, Optional
......
...@@ -12,22 +12,6 @@ ...@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
......
...@@ -12,22 +12,7 @@ ...@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models.""" """Build simclr models."""
from typing import Optional from typing import Optional
from absl import logging from absl import logging
...@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model): ...@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
def checkpoint_items(self): def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed.""" """Returns a dictionary of items to be additionally checkpointed."""
if self._supervised_head is not None: if self._supervised_head is not None:
items = dict(backbone=self.backbone, items = dict(
projection_head=self.projection_head, backbone=self.backbone,
supervised_head=self.supervised_head) projection_head=self.projection_head,
supervised_head=self.supervised_head)
else: else:
items = dict(backbone=self.backbone, items = dict(backbone=self.backbone, projection_head=self.projection_head)
projection_head=self.projection_head)
return items return items
@property @property
......
...@@ -12,22 +12,7 @@ ...@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3 """Test for SimCLR model."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
......
...@@ -12,21 +12,6 @@ ...@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image SimCLR task definition. """Image SimCLR task definition.
SimCLR training two different modes: SimCLR training two different modes:
...@@ -39,7 +24,6 @@ the task definition: ...@@ -39,7 +24,6 @@ the task definition:
- training loss - training loss
- projection_head and/or supervised_head - projection_head and/or supervised_head
""" """
from typing import Dict, Optional from typing import Dict, Optional
from absl import logging from absl import logging
...@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig ...@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
class SimCLRPretrainTask(base_task.Task): class SimCLRPretrainTask(base_task.Task):
"""A task for image classification.""" """A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig, def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None): runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations. """Creates an TF optimizer from configurations.
...@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
Returns: Returns:
A tf.optimizers.Optimizer object. A tf.optimizers.Optimizer object.
""" """
if (optimizer_config.optimizer.type == 'lars' if (optimizer_config.optimizer.type == 'lars' and
and self.task_config.loss.l2_weight_decay > 0.0): self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars ' raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.') 'optimizer. Please set it to 0.')
...@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
def build_model(self): def build_model(self):
model_config = self.task_config.model model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(shape=[None] +
shape=[None] + model_config.input_size) model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss. # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2) # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss) # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2( l2_regularizer = (
l2_weight_decay / 2.0) if l2_weight_decay else None) tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
# Build backbone # Build backbone
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
...@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY] projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
projection1, projection2 = tf.split(projection_outputs, 2, 0) projection1, projection2 = tf.split(projection_outputs, 2, 0)
contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj( contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
projection1=projection1, projection1=projection1, projection2=projection2)
projection2=projection2)
contrast_accuracy = tf.equal( contrast_accuracy = tf.equal(
tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1)) tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
...@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
outputs) outputs)
sup_loss = tf.reduce_mean(sup_loss) sup_loss = tf.reduce_mean(sup_loss)
label_acc = tf.equal(tf.argmax(labels, axis=1), label_acc = tf.equal(
tf.argmax(outputs, axis=1)) tf.argmax(labels, axis=1), tf.argmax(outputs, axis=1))
label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32)) label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
model_loss = contrast_loss + sup_loss model_loss = contrast_loss + sup_loss
...@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
if training: if training:
metrics = [] metrics = []
metric_names = [ metric_names = [
'total_loss', 'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
'contrast_loss',
'contrast_accuracy',
'contrast_entropy'
] ]
if self.task_config.model.supervised_head: if self.task_config.model.supervised_head:
metric_names.extend(['supervised_loss', 'accuracy']) metric_names.extend(['supervised_loss', 'accuracy'])
...@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task):
metrics = [ metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'), tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy( tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))
]
else: else:
metrics = [ metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy( tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))
]
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self, inputs, model, optimizer, metrics=None):
features, labels = inputs features, labels = inputs
if (self.task_config.model.supervised_head is not None if (self.task_config.model.supervised_head is not None and
and self.task_config.evaluation.one_hot): self.task_config.evaluation.one_hot):
num_classes = self.task_config.model.supervised_head.num_classes num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes) labels = tf.one_hot(labels, num_classes)
...@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task):
outputs = model(features, training=True) outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is # Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32. # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure( outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss. # Computes per-replica loss.
losses = self.build_losses( losses = self.build_losses(
...@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task): ...@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task):
class SimCLRFinetuneTask(base_task.Task): class SimCLRFinetuneTask(base_task.Task):
"""A task for image classification.""" """A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig, def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None): runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations. """Creates an TF optimizer from configurations.
...@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task):
Returns: Returns:
A tf.optimizers.Optimizer object. A tf.optimizers.Optimizer object.
""" """
if (optimizer_config.optimizer.type == 'lars' if (optimizer_config.optimizer.type == 'lars' and
and self.task_config.loss.l2_weight_decay > 0.0): self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars ' raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.') 'optimizer. Please set it to 0.')
...@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task):
def build_model(self): def build_model(self):
model_config = self.task_config.model model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(shape=[None] +
shape=[None] + model_config.input_size) model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss. # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2) # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss) # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2( l2_regularizer = (
l2_weight_decay / 2.0) if l2_weight_decay else None) tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone( backbone = backbones.factory.build_backbone(
input_specs=input_specs, input_specs=input_specs,
...@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task):
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed() status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone_projection': elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint(backbone=model.backbone, ckpt = tf.train.Checkpoint(
projection_head=model.projection_head) backbone=model.backbone, projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file) status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched() status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone': elif self.task_config.init_checkpoint_modules == 'backbone':
...@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task):
metrics = [ metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'), tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy( tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))
]
else: else:
metrics = [ metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy( tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))
]
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self, inputs, model, optimizer, metrics=None):
...@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task):
# Computes per-replica loss. # Computes per-replica loss.
loss = self.build_losses( loss = self.build_losses(
model_outputs=outputs, model_outputs=outputs, labels=labels, aux_losses=model.losses)
labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
scaled_loss = loss / num_replicas scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is # For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability. # scaled for numerical stability.
if isinstance( if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss) scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables tvars = model.trainable_variables
...@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task):
grads = tape.gradient(scaled_loss, tvars) grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is # Scales back gradient before apply_gradients when LossScaleOptimizer is
# used. # used.
if isinstance( if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads) grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
...@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task): ...@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task):
num_classes = self.task_config.model.supervised_head.num_classes num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes) labels = tf.one_hot(labels, num_classes)
outputs = self.inference_step( outputs = self.inference_step(features,
features, model)[simclr_model.SUPERVISED_OUTPUT_KEY] model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs) outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, loss = self.build_losses(
labels=labels, aux_losses=model.losses) model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss} logs = {self.loss: loss}
if metrics: if metrics:
......
...@@ -12,22 +12,7 @@ ...@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3 """TensorFlow Model Garden Vision SimCLR trainer."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
from absl import app from absl import app
from absl import flags from absl import flags
import gin import gin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment