Commit aba78478 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 365713370
parent f3f3ec34
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
from typing import Optional
from absl import logging
import tensorflow as tf
layers = tf.keras.layers
PRETRAIN = 'pretrain'
FINETUNE = 'finetune'
PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'
@tf.keras.utils.register_keras_serializable(package='simclr')
class SimCLRModel(tf.keras.Model):
"""A classification model based on SimCLR framework."""
def __init__(self,
backbone: tf.keras.models.Model,
projection_head: tf.keras.layers.Layer,
supervised_head: Optional[tf.keras.layers.Layer] = None,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
mode: str = PRETRAIN,
backbone_trainable: bool = True,
**kwargs):
"""A classification model based on SimCLR framework.
Args:
backbone: a backbone network.
projection_head: a projection head network.
supervised_head: a head network for supervised learning, e.g.
classification head.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
mode: `str` indicates mode of training to be executed.
backbone_trainable: `bool` whether the backbone is trainable or not.
**kwargs: keyword arguments to be passed.
"""
super(SimCLRModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'projection_head': projection_head,
'supervised_head': supervised_head,
'input_specs': input_specs,
'mode': mode,
'backbone_trainable': backbone_trainable,
}
self._input_specs = input_specs
self._backbone = backbone
self._projection_head = projection_head
self._supervised_head = supervised_head
self._mode = mode
self._backbone_trainable = backbone_trainable
# Set whether the backbone is trainable
self._backbone.trainable = backbone_trainable
def call(self, inputs, training=None, **kwargs):
model_outputs = {}
if training and self._mode == PRETRAIN:
num_transforms = 2
else:
num_transforms = 1
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list = tf.split(inputs, num_or_size_splits=num_transforms, axis=-1)
# (num_transforms * bsz, h, w, c)
features = tf.concat(features_list, 0)
# Base network forward pass.
endpoints = self._backbone(features, training=training)
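    # The backbone returns a dict of endpoints keyed by level; use the deepest.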
features = endpoints[max(endpoints.keys())]
projection_inputs = layers.GlobalAveragePooling2D()(features)
# Add heads.
projection_outputs, supervised_inputs = self._projection_head(
projection_inputs, training)
if self._supervised_head is not None:
if self._mode == PRETRAIN:
        logging.info('Ignoring gradient from supervised outputs!')
# When performing pretraining and supervised_head together, we do not
# want information from supervised evaluation flowing back into
# pretraining network. So we put a stop_gradient.
supervised_outputs = self._supervised_head(
tf.stop_gradient(supervised_inputs), training)
else:
supervised_outputs = self._supervised_head(supervised_inputs, training)
else:
supervised_outputs = None
model_outputs.update({
PROJECTION_OUTPUT_KEY: projection_outputs,
SUPERVISED_OUTPUT_KEY: supervised_outputs
})
return model_outputs
@property
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
if self._supervised_head is not None:
items = dict(backbone=self.backbone,
projection_head=self.projection_head,
supervised_head=self.supervised_head)
else:
items = dict(backbone=self.backbone,
projection_head=self.projection_head)
return items
@property
def backbone(self):
return self._backbone
@property
def projection_head(self):
return self._projection_head
@property
def supervised_head(self):
return self._supervised_head
@property
def mode(self):
return self._mode
@mode.setter
def mode(self, value):
self._mode = value
@property
def backbone_trainable(self):
return self._backbone_trainable
@backbone_trainable.setter
def backbone_trainable(self, value):
self._backbone_trainable = value
self._backbone.trainable = value
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model
class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(128, 3, 0),
(128, 3, 1),
(128, 1, 0),
(128, 1, 1),
)
def test_model_creation(self, project_dim, num_proj_layers, ft_proj_idx):
input_size = 224
inputs = np.random.rand(2, input_size, input_size, 3)
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size, input_size, 3])
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.ResNet(model_id=50, activation='relu',
input_specs=input_specs)
projection_head = simclr_head.ProjectionHead(
proj_output_dim=project_dim,
num_proj_layers=num_proj_layers,
ft_proj_idx=ft_proj_idx
)
num_classes = 10
    supervised_head = simclr_head.ClassificationHead(
        num_classes=num_classes
    )
model = simclr_model.SimCLRModel(
input_specs=input_specs,
backbone=backbone,
projection_head=projection_head,
supervised_head=supervised_head,
mode=simclr_model.PRETRAIN
)
outputs = model(inputs)
projection_outputs = outputs[simclr_model.PROJECTION_OUTPUT_KEY]
supervised_outputs = outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
self.assertAllEqual(projection_outputs.shape.as_list(),
[2, project_dim])
self.assertAllEqual([2, num_classes],
supervised_outputs.numpy().shape)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image SimCLR task definition.
SimCLR training has two different modes:
- pretrain
- fine-tuning
For the above two different modes, the following components are different in
the task definition:
- training data format
- training loss
- projection_head and/or supervised_head
"""
from typing import Dict, Optional
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import input_reader
from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
from official.vision.beta.projects.simclr.dataloaders import simclr_input
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
@task_factory.register_task_cls(exp_cfg.SimCLRPretrainTask)
class SimCLRPretrainTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
opt_factory = optimization.OptimizerFactory(optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoid overflow/underflow for float16 computations.
if runtime_config and runtime_config.loss_scale:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale)
return optimizer
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
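    # With this scaling, l2(wd / 2.0)(w) = (wd / 2.0) * sum(w**2), which equals
    # wd * tf.nn.l2_loss(w) since tf.nn.l2_loss(w) = sum(w**2) / 2.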
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
# Build backbone
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
# Build projection head
norm_activation_config = model_config.norm_activation
projection_head_config = model_config.projection_head
projection_head = simclr_head.ProjectionHead(
proj_output_dim=projection_head_config.proj_output_dim,
num_proj_layers=projection_head_config.num_proj_layers,
ft_proj_idx=projection_head_config.ft_proj_idx,
kernel_regularizer=l2_regularizer,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
# Build supervised head
supervised_head_config = model_config.supervised_head
if supervised_head_config:
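      # 'zeros' starts the supervised head from uniform logits; otherwise a
      # randomly initialized kernel is used.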
if supervised_head_config.zero_init:
s_kernel_initializer = 'zeros'
else:
s_kernel_initializer = 'random_uniform'
supervised_head = simclr_head.ClassificationHead(
num_classes=supervised_head_config.num_classes,
kernel_initializer=s_kernel_initializer,
kernel_regularizer=l2_regularizer)
else:
supervised_head = None
model = simclr_model.SimCLRModel(
input_specs=input_specs,
backbone=backbone,
projection_head=projection_head,
supervised_head=supervised_head,
mode=model_config.mode,
backbone_trainable=model_config.backbone_trainable)
logging.info(model.get_config())
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
input_size = self.task_config.model.input_size
if params.tfds_name:
decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
else:
decoder = simclr_input.Decoder(params.decoder.decode_label)
parser = simclr_input.Parser(
output_size=input_size[:2],
aug_rand_crop=params.parser.aug_rand_crop,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_color_distort=params.parser.aug_color_distort,
aug_color_jitter_strength=params.parser.aug_color_jitter_strength,
aug_color_jitter_impl=params.parser.aug_color_jitter_impl,
aug_rand_blur=params.parser.aug_rand_blur,
parse_label=params.parser.parse_label,
test_crop=params.parser.test_crop,
mode=params.parser.mode,
dtype=params.dtype)
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels,
model_outputs,
aux_losses=None) -> Dict[str, tf.Tensor]:
    # Compute the contrastive loss.
con_losses_obj = contrastive_losses.ContrastiveLoss(
projection_norm=self.task_config.loss.projection_norm,
temperature=self.task_config.loss.temperature)
    # The projection outputs from the model have shape (2 * bsz, project_dim).
projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
projection1, projection2 = tf.split(projection_outputs, 2, 0)
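    # projection1 and projection2 are the embeddings of the two augmented views.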
contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
projection1=projection1,
projection2=projection2)
contrast_accuracy = tf.equal(
tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
contrast_accuracy = tf.reduce_mean(tf.cast(contrast_accuracy, tf.float32))
contrast_prob = tf.nn.softmax(contrast_logits)
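    # Mean entropy of the contrastive softmax over the batch; 1e-8 avoids
    # log(0).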
contrast_entropy = -tf.reduce_mean(
tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))
model_loss = contrast_loss
losses = {
'contrast_loss': contrast_loss,
'contrast_accuracy': contrast_accuracy,
'contrast_entropy': contrast_entropy
}
if self.task_config.model.supervised_head is not None:
outputs = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
labels = tf.concat([labels, labels], 0)
if self.task_config.evaluation.one_hot:
        sup_loss = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE)(labels, outputs)
      else:
        sup_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE)(labels, outputs)
sup_loss = tf.reduce_mean(sup_loss)
label_acc = tf.equal(tf.argmax(labels, axis=1),
tf.argmax(outputs, axis=1))
label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
model_loss = contrast_loss + sup_loss
losses.update({
'accuracy': label_acc,
'supervised_loss': sup_loss,
})
total_loss = model_loss
if aux_losses:
reg_loss = tf.reduce_sum(aux_losses)
total_loss = model_loss + reg_loss
losses['total_loss'] = total_loss
return losses
def build_metrics(self, training=True):
if training:
metrics = []
metric_names = [
'total_loss',
'contrast_loss',
'contrast_accuracy',
'contrast_entropy'
]
if self.task_config.model.supervised_head:
metric_names.extend(['supervised_loss', 'accuracy'])
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else:
k = self.task_config.evaluation.top_k
if self.task_config.evaluation.one_hot:
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
features, labels = inputs
if (self.task_config.model.supervised_head is not None
and self.task_config.evaluation.one_hot):
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure the output is cast to float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
losses = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
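      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.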
scaled_loss = losses['total_loss'] / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
logging.info('Trainable variables:')
for var in tvars:
logging.info(var.name)
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient when LossScaleOptimizer is used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: losses['total_loss']}
    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
        logs.update({m.name: m.result()})
return logs
def validation_step(self, inputs, model, metrics=None):
if self.task_config.model.supervised_head is None:
      raise ValueError(
          'Skipping eval during pretraining without supervised head.')
features, labels = inputs
if self.task_config.evaluation.one_hot:
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
outputs = model(
features, training=False)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
logs = {self.loss: 0}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
@task_factory.register_task_cls(exp_cfg.SimCLRFinetuneTask)
class SimCLRFinetuneTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
opt_factory = optimization.OptimizerFactory(optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoid overflow/underflow for float16 computations.
if runtime_config and runtime_config.loss_scale:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == 'float16',
loss_scale=runtime_config.loss_scale)
return optimizer
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
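    # With this scaling, l2(wd / 2.0)(w) = (wd / 2.0) * sum(w**2), which equals
    # wd * tf.nn.l2_loss(w) since tf.nn.l2_loss(w) = sum(w**2) / 2.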
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
norm_activation_config = model_config.norm_activation
projection_head_config = model_config.projection_head
projection_head = simclr_head.ProjectionHead(
proj_output_dim=projection_head_config.proj_output_dim,
num_proj_layers=projection_head_config.num_proj_layers,
ft_proj_idx=projection_head_config.ft_proj_idx,
kernel_regularizer=l2_regularizer,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
supervised_head_config = model_config.supervised_head
if supervised_head_config.zero_init:
s_kernel_initializer = 'zeros'
else:
s_kernel_initializer = 'random_uniform'
supervised_head = simclr_head.ClassificationHead(
num_classes=supervised_head_config.num_classes,
kernel_initializer=s_kernel_initializer,
kernel_regularizer=l2_regularizer)
model = simclr_model.SimCLRModel(
input_specs=input_specs,
backbone=backbone,
projection_head=projection_head,
supervised_head=supervised_head,
mode=model_config.mode,
backbone_trainable=model_config.backbone_trainable)
logging.info(model.get_config())
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint(backbone=model.backbone,
projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
      raise ValueError(
          "Only 'all', 'backbone_projection' or 'backbone' can be used to "
          'initialize the model.')
# If the checkpoint is from pretraining, reset the following parameters
model.backbone_trainable = self.task_config.model.backbone_trainable
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
input_size = self.task_config.model.input_size
if params.tfds_name:
decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
else:
decoder = simclr_input.Decoder(params.decoder.decode_label)
parser = simclr_input.Parser(
output_size=input_size[:2],
parse_label=params.parser.parse_label,
test_crop=params.parser.test_crop,
mode=params.parser.mode,
dtype=params.dtype)
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
"""Sparse categorical cross entropy loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
losses_config = self.task_config.loss
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
total_loss = tf_utils.safe_mean(total_loss)
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation."""
k = self.task_config.evaluation.top_k
if self.task_config.evaluation.one_hot:
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
if self.task_config.loss.one_hot:
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(
features, training=True)[simclr_model.SUPERVISED_OUTPUT_KEY]
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure the output is cast to float32.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs,
labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
logging.info('Trainable variables:')
for var in tvars:
logging.info(var.name)
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
if self.task_config.loss.one_hot:
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
outputs = self.inference_step(
features, model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs,
labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.simclr.common import registry_imports # pylint: disable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
print(FLAGS.experiment)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise the continuous eval
    # job may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)