Commit 96eed905 authored by A. Unique TensorFlower
Browse files

Adding gradient clipping for detection models.

PiperOrigin-RevId: 365639389
parent 1dc59163
......@@ -17,7 +17,7 @@
from official.modeling.hyperparams import params_dict
from official.vision.detection.configs import base_config
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
SHAPEMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
SHAPEMASK_CFG.override({
......
......@@ -63,10 +63,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables)
logging.info('Filter trainable variables from %d to %d',
len(model.trainable_variables), len(trainable_variables))
_update_state = lambda labels, outputs: None
update_state_fn = lambda labels, outputs: None
if isinstance(metric, tf.keras.metrics.Metric):
_update_state = lambda labels, outputs: metric.update_state(
labels, outputs)
update_state_fn = metric.update_state
else:
logging.error('Detection: train metric is not an instance of '
'tf.keras.metrics.Metric.')
......@@ -82,10 +81,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
for k, v in all_losses.items():
losses[k] = tf.reduce_mean(v)
per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
_update_state(labels, outputs)
update_state_fn(labels, outputs)
grads = tape.gradient(per_replica_loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
return losses
return _replicated_step
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment