Commit 96eed905 authored by A. Unique TensorFlower

Adding gradient clipping for detection models.

PiperOrigin-RevId: 365639389
parent 1dc59163
@@ -17,7 +17,7 @@
 from official.modeling.hyperparams import params_dict
 from official.vision.detection.configs import base_config
-SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+/)conv2d(|_([1-9]|10))\/'
+SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
 SHAPEMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
 SHAPEMASK_CFG.override({
...
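The hunk above broadens the ShapeMask frozen-variable pattern from resnet-prefixed conv layers to conv and batch normalization layers. A minimal sketch of how such a pattern is typically applied when filtering the trainable set; the `filter_trainable_variables` helper and the `re.search`-style matching are illustrative assumptions, not code from this commit:

```python
import re

# The new pattern from the config change above.
FROZEN_VAR_PATTERN = re.compile(
    r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/')

def filter_trainable_variables(variables, pattern=FROZEN_VAR_PATTERN):
  """Hypothetical helper: drop variables whose names match the frozen pattern."""
  return [v for v in variables if not pattern.search(v.name)]
```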
@@ -63,10 +63,9 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
           trainable_variables)
       logging.info('Filter trainable variables from %d to %d',
                    len(model.trainable_variables), len(trainable_variables))
-      _update_state = lambda labels, outputs: None
+      update_state_fn = lambda labels, outputs: None
       if isinstance(metric, tf.keras.metrics.Metric):
-        _update_state = lambda labels, outputs: metric.update_state(
-            labels, outputs)
+        update_state_fn = metric.update_state
       else:
         logging.error('Detection: train metric is not an instance of '
                       'tf.keras.metrics.Metric.')
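This hunk replaces a pass-through lambda with the metric's bound `update_state` method, which is equivalent and avoids an extra call frame. A minimal standalone sketch of the same pattern, assuming a metric whose `update_state` takes `(y_true, y_pred)`; `BinaryAccuracy` here is illustrative:

```python
import tensorflow as tf

metric = tf.keras.metrics.BinaryAccuracy()

update_state_fn = lambda labels, outputs: None  # no-op fallback
if isinstance(metric, tf.keras.metrics.Metric):
  # The bound method is itself a (labels, outputs) callable.
  update_state_fn = metric.update_state

update_state_fn(tf.constant([1.0, 0.0]), tf.constant([0.9, 0.1]))
print(metric.result().numpy())  # 1.0
```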
@@ -82,10 +81,11 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
         for k, v in all_losses.items():
           losses[k] = tf.reduce_mean(v)
         per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
-        _update_state(labels, outputs)
+        update_state_fn(labels, outputs)
       grads = tape.gradient(per_replica_loss, trainable_variables)
-      optimizer.apply_gradients(zip(grads, trainable_variables))
+      clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
+      optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
       return losses
     return _replicated_step
...
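The final hunk adds the gradient clipping named in the commit message: `tf.clip_by_global_norm` rescales the whole gradient list by `clip_norm / max(global_norm, clip_norm)`, so the joint L2 norm of all gradients never exceeds `clip_norm` (1.0 here). A self-contained sketch of the same training-step pattern; the model, optimizer, and data below are illustrative, not from this commit:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(model(x) - y))

grads = tape.gradient(loss, model.trainable_variables)
# Rescale so the global L2 norm of all gradients is at most 1.0.
clipped_grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables))
```

Clipping by global norm, rather than per tensor, preserves the direction of the overall update while bounding its magnitude, which helps stabilize training when loss spikes would otherwise produce exploding gradients.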