"vscode:/vscode.git/clone" did not exist on "b29c5537480e9ce0c3b7a36719719c4cca8027fc"
Commit 09a70c7c authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 374451731
parent 23db25e9
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 18291
    input_size: [224, 224, 3]
    backbone:
      type: 'dilated_resnet'
      dilated_resnet:
        model_id: 101
        output_stride: 16
        stem_type: 'v1'
        multigrid: [1, 2, 4]
    norm_activation:
      activation: 'swish'
  losses:
    l2_weight_decay: 0.0
  train_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'train'
    is_training: true
    global_batch_size: 3840
    is_multilabel: true
    shuffle_buffer_size: 500000
    dtype: 'bfloat16'
  validation_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'validation'
    is_training: false
    global_batch_size: 3840
    is_multilabel: true
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  train_steps: 2220000  # 30 epochs
  validation_steps: 156
  validation_interval: 2000
  steps_per_loop: 100
  summary_interval: 2000
  checkpoint_interval: 2000
  best_checkpoint_eval_metric: 'globalPR-AUC'
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_metric_comp: 'higher'
  optimizer_config:
    ema: null
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        values: [0.48, 0.048, 0.0048, 0.00048]
        boundaries: [730000, 1460000, 1850000]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 5000
runtime:
  distribution_strategy: 'tpu'
  mixed_precision_dtype: 'bfloat16'
task:
  model:
    num_classes: 18291
    input_size: [160, 160, 3]
    backbone:
      type: 'resnet'
      resnet:
        model_id: 50
        replace_stem_max_pool: true
        resnetd_shortcut: true
        se_ratio: 0.25
        stem_type: 'v1'
        stochastic_depth_drop_rate: 0.0
    norm_activation:
      activation: 'swish'
      norm_momentum: 0.0
      use_sync_bn: false
    dropout_rate: 0.25
  losses:
    l2_weight_decay: 0.00004
  train_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'train'
    is_training: true
    global_batch_size: 4096
    is_multilabel: true
    shuffle_buffer_size: 500000
    dtype: 'bfloat16'
    aug_type: null
  validation_data:
    input_path: ''
    tfds_name: 'jft/entity'
    tfds_split: 'validation'
    is_training: false
    global_batch_size: 4096
    is_multilabel: true
    dtype: 'bfloat16'
    drop_remainder: false
trainer:
  train_steps: 2220000  # 30 epochs
  validation_steps: 156
  validation_interval: 2000
  steps_per_loop: 100
  summary_interval: 2000
  checkpoint_interval: 2000
  best_checkpoint_eval_metric: 'globalPR-AUC'
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_metric_comp: 'higher'
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'stepwise'
      stepwise:
        values: [0.48, 0.048, 0.0048, 0.00048]
        boundaries: [730000, 1460000, 1850000]
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 5000
...@@ -75,15 +75,18 @@ class ImageClassificationTask(base_task.Task): ...@@ -75,15 +75,18 @@ class ImageClassificationTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, def build_inputs(
params: exp_cfg.DataConfig, self,
input_context: Optional[tf.distribute.InputContext] = None): params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Builds classification input.""" """Builds classification input."""
num_classes = self.task_config.model.num_classes num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size input_size = self.task_config.model.input_size
image_field_key = self.task_config.train_data.image_field_key image_field_key = self.task_config.train_data.image_field_key
label_field_key = self.task_config.train_data.label_field_key label_field_key = self.task_config.train_data.label_field_key
is_multilabel = self.task_config.train_data.is_multilabel
if params.tfds_name: if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP: if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
...@@ -93,7 +96,8 @@ class ImageClassificationTask(base_task.Task): ...@@ -93,7 +96,8 @@ class ImageClassificationTask(base_task.Task):
raise ValueError('TFDS {} is not supported'.format(params.tfds_name)) raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else: else:
decoder = classification_input.Decoder( decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key) image_field_key=image_field_key, label_field_key=label_field_key,
is_multilabel=is_multilabel)
parser = classification_input.Parser( parser = classification_input.Parser(
output_size=input_size[:2], output_size=input_size[:2],
...@@ -102,6 +106,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -102,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key=label_field_key, label_field_key=label_field_key,
aug_rand_hflip=params.aug_rand_hflip, aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type, aug_type=params.aug_type,
is_multilabel=is_multilabel,
dtype=params.dtype) dtype=params.dtype)
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
...@@ -117,7 +122,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -117,7 +122,7 @@ class ImageClassificationTask(base_task.Task):
def build_losses(self, def build_losses(self,
labels: tf.Tensor, labels: tf.Tensor,
model_outputs: tf.Tensor, model_outputs: tf.Tensor,
aux_losses: Optional[Any] = None): aux_losses: Optional[Any] = None) -> tf.Tensor:
"""Builds sparse categorical cross entropy loss. """Builds sparse categorical cross entropy loss.
Args: Args:
...@@ -129,15 +134,23 @@ class ImageClassificationTask(base_task.Task): ...@@ -129,15 +134,23 @@ class ImageClassificationTask(base_task.Task):
The total loss tensor. The total loss tensor.
""" """
losses_config = self.task_config.losses losses_config = self.task_config.losses
if losses_config.one_hot: is_multilabel = self.task_config.train_data.is_multilabel
total_loss = tf.keras.losses.categorical_crossentropy(
labels, if not is_multilabel:
model_outputs, if losses_config.one_hot:
from_logits=True, total_loss = tf.keras.losses.categorical_crossentropy(
label_smoothing=losses_config.label_smoothing) labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
else: else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy( # Multi-label weighted binary cross entropy loss.
labels, model_outputs, from_logits=True) total_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=model_outputs)
total_loss = tf.reduce_sum(total_loss, axis=-1)
total_loss = tf_utils.safe_mean(total_loss) total_loss = tf_utils.safe_mean(total_loss)
if aux_losses: if aux_losses:
...@@ -145,19 +158,41 @@ class ImageClassificationTask(base_task.Task): ...@@ -145,19 +158,41 @@ class ImageClassificationTask(base_task.Task):
return total_loss return total_loss
def build_metrics(self, training: bool = True): def build_metrics(self,
training: bool = True) -> List[tf.keras.metrics.Metric]:
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
k = self.task_config.evaluation.top_k is_multilabel = self.task_config.train_data.is_multilabel
if self.task_config.losses.one_hot: if not is_multilabel:
metrics = [ k = self.task_config.evaluation.top_k
tf.keras.metrics.CategoricalAccuracy(name='accuracy'), if self.task_config.losses.one_hot:
tf.keras.metrics.TopKCategoricalAccuracy( metrics = [
k=k, name='top_{}_accuracy'.format(k))] tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=k, name='top_{}_accuracy'.format(k))]
else: else:
metrics = [ metrics = []
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), # These metrics destablize the training if included in training. The jobs
tf.keras.metrics.SparseTopKCategoricalAccuracy( # fail due to OOM.
k=k, name='top_{}_accuracy'.format(k))] # TODO(arashwan): Investigate adding following metric to train.
if not training:
metrics = [
tf.keras.metrics.AUC(
name='globalPR-AUC',
curve='PR',
multi_label=False,
from_logits=True),
tf.keras.metrics.AUC(
name='meanlPR-AUC',
curve='PR',
multi_label=True,
num_labels=self.task_config.model.num_classes,
from_logits=True),
]
return metrics return metrics
def train_step(self, def train_step(self,
...@@ -177,7 +212,8 @@ class ImageClassificationTask(base_task.Task): ...@@ -177,7 +212,8 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs. A dictionary of logs.
""" """
features, labels = inputs features, labels = inputs
if self.task_config.losses.one_hot: is_multilabel = self.task_config.train_data.is_multilabel
if self.task_config.losses.one_hot and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes) labels = tf.one_hot(labels, self.task_config.model.num_classes)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
...@@ -233,7 +269,8 @@ class ImageClassificationTask(base_task.Task): ...@@ -233,7 +269,8 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs. A dictionary of logs.
""" """
features, labels = inputs features, labels = inputs
if self.task_config.losses.one_hot: is_multilabel = self.task_config.train_data.is_multilabel
if self.task_config.losses.one_hot and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes) labels = tf.one_hot(labels, self.task_config.model.num_classes)
outputs = self.inference_step(features, model) outputs = self.inference_step(features, model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment