"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "c2d48bedf8ef42a7b2e66ff9a285cc5de37f70e4"
Unverified commit 0b0dc7f5, authored by pkanwar23 and committed by GitHub

Adding LARS to ResNet (#6327)

* Adding LARS to ResNet

* Fixes for the LARS patch

* Fixes for the LARS patch

* more fixes

* 1 more fix
parent 05584085
@@ -313,7 +313,7 @@ def imagenet_model_fn(features, labels, mode, params):
       mode=mode,
       model_class=ImagenetModel,
       resnet_size=params['resnet_size'],
-      weight_decay=1e-4,
+      weight_decay=flags.FLAGS.weight_decay,
       learning_rate_fn=learning_rate_fn,
       momentum=0.9,
       data_format=params['data_format'],
@@ -321,7 +321,8 @@ def imagenet_model_fn(features, labels, mode, params):
       loss_scale=params['loss_scale'],
       loss_filter_fn=None,
       dtype=params['dtype'],
-      fine_tune=params['fine_tune']
+      fine_tune=params['fine_tune'],
+      label_smoothing=flags.FLAGS.label_smoothing
   )
@@ -266,6 +266,59 @@ def learning_rate_with_decay(
         false_fn=lambda: lr)
     return lr
 
+  def poly_rate_fn(global_step):
+    """Handles linear scaling rule, gradual warmup, and LR decay.
+
+    The learning rate starts at 0, then it increases linearly per step. After
+    FLAGS.poly_warmup_epochs, we reach the base learning rate (scaled to
+    account for batch size). The learning rate is then decayed using a
+    polynomial rate decay schedule with power 2.0.
+
+    Args:
+      global_step: the current global_step
+
+    Returns:
+      returns the current learning rate
+    """
+
+    # Learning rate schedule for LARS polynomial schedule
+    if flags.FLAGS.batch_size < 8192:
+      plr = 5.0
+      w_epochs = 5
+    elif flags.FLAGS.batch_size < 16384:
+      plr = 10.0
+      w_epochs = 5
+    elif flags.FLAGS.batch_size < 32768:
+      plr = 25.0
+      w_epochs = 5
+    else:
+      plr = 32.0
+      w_epochs = 14
+
+    w_steps = int(w_epochs * batches_per_epoch)
+    wrate = (plr * tf.cast(global_step, tf.float32) / tf.cast(
+        w_steps, tf.float32))
+
+    # TODO(pkanwar): use a flag to help calc num_epochs.
+    num_epochs = 90
+    train_steps = batches_per_epoch * num_epochs
+
+    min_step = tf.constant(1, dtype=tf.int64)
+    decay_steps = tf.maximum(min_step, tf.subtract(global_step, w_steps))
+    poly_rate = tf.train.polynomial_decay(
+        plr,
+        decay_steps,
+        train_steps - w_steps + 1,
+        power=2.0)
+    return tf.where(global_step <= w_steps, wrate, poly_rate)
+
+  # For LARS we have a new learning rate schedule
+  if flags.FLAGS.enable_lars:
+    return poly_rate_fn
+
   return learning_rate_fn
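The poly_rate_fn added above warms the learning rate up linearly to a batch-size-dependent peak (plr) over w_epochs, then decays it with a power-2.0 polynomial over the remainder of a 90-epoch run. For intuition only, here is a minimal pure-Python sketch of the same shape; the 4096-image batch (hence plr = 5.0, w_epochs = 5) and the batches_per_epoch value are hypothetical, and the decay floor is taken as 0 rather than tf.train.polynomial_decay's small default end rate.

# Hypothetical values chosen only for illustration.
batches_per_epoch = 312          # ~1.28M ImageNet images / 4096 per batch
plr, w_epochs, num_epochs = 5.0, 5, 90

w_steps = int(w_epochs * batches_per_epoch)
train_steps = int(num_epochs * batches_per_epoch)

def poly_lr(global_step):
  """Linear warmup to plr, then polynomial (power 2.0) decay toward 0."""
  if global_step <= w_steps:
    return plr * global_step / w_steps
  decay_steps = min(max(1, global_step - w_steps), train_steps - w_steps + 1)
  frac = decay_steps / (train_steps - w_steps + 1)
  return plr * (1.0 - frac) ** 2.0

print(poly_lr(w_steps))          # peak rate: 5.0
print(poly_lr(train_steps))      # near the end of training: ~0.0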
@@ -273,7 +326,7 @@ def resnet_model_fn(features, labels, mode, model_class,
                     resnet_size, weight_decay, learning_rate_fn, momentum,
                     data_format, resnet_version, loss_scale,
                     loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
-                    fine_tune=False):
+                    fine_tune=False, label_smoothing=0.0):
   """Shared functionality for different resnet model_fns.
 
   Initializes the ResnetModel representing the model layers
...@@ -343,8 +396,14 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -343,8 +396,14 @@ def resnet_model_fn(features, labels, mode, model_class,
}) })
# Calculate loss, which includes softmax cross entropy and L2 regularization. # Calculate loss, which includes softmax cross entropy and L2 regularization.
cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy( if label_smoothing != 0.0:
logits=logits, labels=labels) one_hot_labels = tf.one_hot(labels, 1001)
cross_entropy = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=one_hot_labels,
label_smoothing=label_smoothing)
else:
cross_entropy = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
# Create a tensor named cross_entropy for logging purposes. # Create a tensor named cross_entropy for logging purposes.
tf.identity(cross_entropy, name='cross_entropy') tf.identity(cross_entropy, name='cross_entropy')
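The new label_smoothing branch above softens the one-hot targets before the cross entropy is computed: with smoothing value s and 1001 classes (1000 ImageNet classes plus background), each target row becomes one_hot * (1 - s) + s / 1001, which is how tf.losses.softmax_cross_entropy applies its label_smoothing argument. Below is a small NumPy sketch of that effect, written for illustration rather than taken from the commit; the smoothing value and class index are hypothetical.

import numpy as np

num_classes = 1001         # 1000 ImageNet classes + background, as in the diff
s = 0.1                    # hypothetical label_smoothing value
label = 42                 # hypothetical ground-truth class

one_hot = np.zeros(num_classes)
one_hot[label] = 1.0

# Softened target distribution used by the smoothed cross entropy.
smoothed = one_hot * (1.0 - s) + s / num_classes

print(round(smoothed[label], 4))   # 0.9001 instead of 1.0
print(round(smoothed[0], 6))       # 0.0001 instead of 0.0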
...@@ -378,10 +437,17 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -378,10 +437,17 @@ def resnet_model_fn(features, labels, mode, model_class,
tf.identity(learning_rate, name='learning_rate') tf.identity(learning_rate, name='learning_rate')
tf.compat.v1.summary.scalar('learning_rate', learning_rate) tf.compat.v1.summary.scalar('learning_rate', learning_rate)
optimizer = tf.compat.v1.train.MomentumOptimizer( if flags.FLAGS.enable_lars:
learning_rate=learning_rate, optimizer = tf.contrib.opt.LARSOptimizer(
momentum=momentum learning_rate,
) momentum=momentum,
weight_decay=weight_decay,
skip_list=['batch_normalization', 'bias'])
else:
optimizer = tf.compat.v1.train.MomentumOptimizer(
learning_rate=learning_rate,
momentum=momentum
)
def _dense_grad_filter(gvs): def _dense_grad_filter(gvs):
"""Only apply gradient updates to the final layer. """Only apply gradient updates to the final layer.
...@@ -689,6 +755,19 @@ def define_resnet_flags(resnet_size_choices=None): ...@@ -689,6 +755,19 @@ def define_resnet_flags(resnet_size_choices=None):
name='task_index', default=-1, name='task_index', default=-1,
help=flags_core.help_wrap('If multi-worker training, the task_index of ' help=flags_core.help_wrap('If multi-worker training, the task_index of '
'this worker.')) 'this worker.'))
flags.DEFINE_bool(
name='enable_lars', default=False,
help=flags_core.help_wrap(
'Enable LARS optimizer for large batch training.'))
flags.DEFINE_float(
name='label_smoothing', default=0.0,
help=flags_core.help_wrap(
'Label smoothing parameter used in the softmax_cross_entropy'))
flags.DEFINE_float(
name='weight_decay', default=1e-4,
help=flags_core.help_wrap(
'Weight decay coefficiant for l2 regularization.'))
choice_kwargs = dict( choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50', name='resnet_size', short_name='rs', default='50',
help=flags_core.help_wrap('The size of the ResNet model to use.')) help=flags_core.help_wrap('The size of the ResNet model to use.'))
......
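The three flags defined above gate everything in this change: enable_lars switches in both poly_rate_fn and the LARS optimizer, label_smoothing flows through resnet_model_fn into the loss, and weight_decay replaces the previously hard-coded 1e-4. A minimal sketch of how such absl flags are consumed is shown below, with illustrative values; the flag definitions mirror the diff, but the direct FLAGS(...) parse call and the chosen values are only for demonstration.

from absl import flags

flags.DEFINE_bool(name='enable_lars', default=False,
                  help='Enable LARS optimizer for large batch training.')
flags.DEFINE_float(name='label_smoothing', default=0.0,
                   help='Label smoothing used in softmax_cross_entropy.')
flags.DEFINE_float(name='weight_decay', default=1e-4,
                   help='Weight decay coefficient for L2 regularization.')

# Illustrative command-line values for a large-batch LARS run.
flags.FLAGS(['prog', '--enable_lars', '--label_smoothing=0.1',
             '--weight_decay=1e-4'])

if flags.FLAGS.enable_lars:
  print('LARS enabled: weight_decay=%g, label_smoothing=%g' %
        (flags.FLAGS.weight_decay, flags.FLAGS.label_smoothing))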