Commit 59e1ab8a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388514034
parent 08b68031
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
......
......@@ -12,26 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations."""
import os
import dataclasses
import os.path
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
projection_head: ProjectionHead = ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1)
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
......@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
......@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.000001,
'momentum':
0.9,
'weight_decay_rate':
0.000001,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
'batch_normalization', 'bias'
]
}
},
'learning_rate': {
......@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
supervised_head=SupervisedHead(
num_classes=1001, zero_init=True),
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
loss=ClassificationLosses(),
......@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.0,
'momentum':
0.9,
'weight_decay_rate':
0.0,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
'batch_normalization', 'bias'
]
}
},
'learning_rate': {
......
......@@ -12,23 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
"""Tests for SimCLR config."""
from absl.testing import parameterized
import tensorflow as tf
......
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops."""
import functools
import tensorflow as tf
......
......@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR.
For pre-training:
......
......@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
"""SimCLR prediction heads."""
from typing import Text, Optional
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
......
......@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions."""
import functools
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import numpy as np
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for simclr neural networks."""
from typing import Text, Optional
......
......@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import parameterized
import tensorflow as tf
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
from typing import Optional
from absl import logging
......@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
def checkpoint_items(self):
"""Returns a dictionary of items to be additionally checkpointed."""
if self._supervised_head is not None:
items = dict(backbone=self.backbone,
projection_head=self.projection_head,
supervised_head=self.supervised_head)
items = dict(
backbone=self.backbone,
projection_head=self.projection_head,
supervised_head=self.supervised_head)
else:
items = dict(backbone=self.backbone,
projection_head=self.projection_head)
items = dict(backbone=self.backbone, projection_head=self.projection_head)
return items
@property
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for SimCLR model."""
from absl.testing import parameterized
import numpy as np
......
......@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image SimCLR task definition.
SimCLR training two different modes:
......@@ -39,7 +24,6 @@ the task definition:
- training loss
- projection_head and/or supervised_head
"""
from typing import Dict, Optional
from absl import logging
......@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
class SimCLRPretrainTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
......@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
if (optimizer_config.optimizer.type == 'lars' and
self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
......@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
input_specs = tf.keras.layers.InputSpec(shape=[None] +
model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
# Build backbone
backbone = backbones.factory.build_backbone(
......@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
projection1, projection2 = tf.split(projection_outputs, 2, 0)
contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
projection1=projection1,
projection2=projection2)
projection1=projection1, projection2=projection2)
contrast_accuracy = tf.equal(
tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
......@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
outputs)
sup_loss = tf.reduce_mean(sup_loss)
label_acc = tf.equal(tf.argmax(labels, axis=1),
tf.argmax(outputs, axis=1))
label_acc = tf.equal(
tf.argmax(labels, axis=1), tf.argmax(outputs, axis=1))
label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
model_loss = contrast_loss + sup_loss
......@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
if training:
metrics = []
metric_names = [
'total_loss',
'contrast_loss',
'contrast_accuracy',
'contrast_entropy'
'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
]
if self.task_config.model.supervised_head:
metric_names.extend(['supervised_loss', 'accuracy'])
......@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
features, labels = inputs
if (self.task_config.model.supervised_head is not None
and self.task_config.evaluation.one_hot):
if (self.task_config.model.supervised_head is not None and
self.task_config.evaluation.one_hot):
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
......@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task):
outputs = model(features, training=True)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
losses = self.build_losses(
......@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task):
class SimCLRFinetuneTask(base_task.Task):
"""A task for image classification."""
def create_optimizer(self, optimizer_config: OptimizationConfig,
def create_optimizer(self,
optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
......@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task):
Returns:
A tf.optimizers.Optimizer object.
"""
if (optimizer_config.optimizer.type == 'lars'
and self.task_config.loss.l2_weight_decay > 0.0):
if (optimizer_config.optimizer.type == 'lars' and
self.task_config.loss.l2_weight_decay > 0.0):
raise ValueError('The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.')
......@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task):
def build_model(self):
model_config = self.task_config.model
input_specs = tf.keras.layers.InputSpec(
shape=[None] + model_config.input_size)
input_specs = tf.keras.layers.InputSpec(shape=[None] +
model_config.input_size)
l2_weight_decay = self.task_config.loss.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
......@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task):
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone_projection':
ckpt = tf.train.Checkpoint(backbone=model.backbone,
projection_head=model.projection_head)
ckpt = tf.train.Checkpoint(
backbone=model.backbone, projection_head=model.projection_head)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
elif self.task_config.init_checkpoint_modules == 'backbone':
......@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
k=k, name='top_{}_accuracy'.format(k))
]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
......@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task):
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs,
labels=labels, aux_losses=model.losses)
model_outputs=outputs, labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
......@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task):
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
......@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task):
num_classes = self.task_config.model.supervised_head.num_classes
labels = tf.one_hot(labels, num_classes)
outputs = self.inference_step(
features, model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = self.inference_step(features,
model)[simclr_model.SUPERVISED_OUTPUT_KEY]
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs,
labels=labels, aux_losses=model.losses)
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
......
......@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
"""TensorFlow Model Garden Vision SimCLR trainer."""
from absl import app
from absl import flags
import gin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment