"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "ebd0e1c18bf2f983c8302a609e606e978b2d5a6c"
Commit 7bffd37b authored by Zac Wellmer, committed by Taylor Robie

Resnet transfer learning (#5047)

* warm start a resnet with all but the dense layer, and only update the final layer weights when fine tuning

* Update README for Transfer Learning

* make lint happy and fix a variable naming error related to scaled gradients

* edit the test cases for cifar10 and imagenet to reflect the default case of no fine tuning
parent c2f755c4
......@@ -75,3 +75,19 @@ ResNet-50 v2 (fp16, Accuracy 75.56%):
ResNet-50 v1 (Accuracy 75.91%):
* [Checkpoint](http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz)
* [SavedModel](http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_savedmodel.tar.gz)
### Transfer Learning
You can use a pretrained model to initialize a training process. In addition, you can freeze all but the final fully connected layer to fine-tune your model. Transfer learning is useful when training on your own small datasets. For a brief look at transfer learning in the context of convolutional neural networks, we recommend reading these [short notes](http://cs231n.github.io/transfer-learning/).
To fine-tune a pretrained ResNet, you must make three changes to your training procedure:
1) Build the exact same model as before, except change the number of labels in the final classification layer.
2) Restore all weights from the pre-trained ResNet except for the final classification layer, which will be randomly initialized instead.
3) Freeze the earlier layers of the network.
We can perform these three operations by specifying two flags: ```--pretrained_model_checkpoint_path``` and ```--fine_tune```. The first flag is a string that points to the path of a pre-trained ResNet checkpoint. If this flag is specified, all weights except those of the final classification layer are loaded from it. A key thing to note: if both ```--pretrained_model_checkpoint_path``` and a non-empty ```model_dir``` directory are passed, the TensorFlow Estimator will load only from ```model_dir```. For more on this, please see [WarmStartSettings](https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator/WarmStartSettings) and [Estimators](https://www.tensorflow.org/guide/estimators).
The second flag, ```--fine_tune```, is a boolean that indicates whether the earlier layers of the network should be frozen. Set this flag to false if you wish to continue training a pre-trained model from a checkpoint; set it to true to train only a new classification layer from scratch.
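For example, a fine-tuning run on ImageNet could look like the sketch below. The entry-point script name (```imagenet_main.py```) and all paths are illustrative placeholders; adjust them for your checkout and data.
```
python imagenet_main.py \
  --data_dir=/path/to/imagenet \
  --model_dir=/path/to/new/empty/model_dir \
  --pretrained_model_checkpoint_path=/path/to/pretrained/checkpoint \
  --fine_tune
```
Point ```--model_dir``` at a new, empty directory; as noted above, an existing ```model_dir``` takes precedence over the pretrained checkpoint.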
......@@ -221,7 +221,8 @@ def cifar10_model_fn(features, labels, mode, params):
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
dtype=params['dtype']
dtype=params['dtype'],
fine_tune=params['fine_tune']
)
......
......@@ -89,6 +89,7 @@ class BaseTest(tf.test.TestCase):
'batch_size': _BATCH_SIZE,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'fine_tune': False,
})
predictions = spec.predictions
......
......@@ -303,7 +303,8 @@ def imagenet_model_fn(features, labels, mode, params):
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype']
dtype=params['dtype'],
fine_tune=params['fine_tune']
)
......
......@@ -203,6 +203,7 @@ class BaseTest(tf.test.TestCase):
'batch_size': _BATCH_SIZE,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'fine_tune': False,
})
predictions = spec.predictions
......
......@@ -177,7 +177,8 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, resnet_version, loss_scale,
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
fine_tune=False):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
......@@ -210,6 +211,7 @@ def resnet_model_fn(features, labels, mode, model_class,
otherwise. If None, batch_normalization variables will be excluded
from the loss.
dtype: the TensorFlow dtype to use for calculations.
fine_tune: If True, only train the dense (final) layers.
Returns:
EstimatorSpec parameterized according to the input params and the
......@@ -281,19 +283,36 @@ def resnet_model_fn(features, labels, mode, model_class,
momentum=momentum
)
def _dense_grad_filter(gvs):
"""Only apply gradient updates to the final layer.
This function is used for fine tuning.
Args:
gvs: list of tuples with gradients and variable info.
Returns:
Filtered gradients so that only the dense layer remains.
"""
return [(g, v) for g, v in gvs if 'dense' in v.name]
if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are
# so small, they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scale times bigger.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
if fine_tune:
scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)
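# Filtering before unscaling means that only the dense-layer gradients are
# rescaled and passed to apply_gradients below.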
# Once the gradient computation is complete we can scale the gradients
# back to the correct scale before passing them to the optimizer.
unscaled_grad_vars = [(grad / loss_scale, var)
for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
else:
minimize_op = optimizer.minimize(loss, global_step)
grad_vars = optimizer.compute_gradients(loss)
if fine_tune:
grad_vars = _dense_grad_filter(grad_vars)
minimize_op = optimizer.apply_gradients(grad_vars, global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(minimize_op, update_ops)
......@@ -354,15 +373,24 @@ def resnet_main(
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)
# initialize our model with all but the dense layer from pretrained resnet
if flags_obj.pretrained_model_checkpoint_path is not None:
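# '^(?!.*dense)' is a negative-lookahead regex: it matches every variable
# whose name does not contain 'dense', so all layers except the final dense
# (classification) layer are warm-started from the checkpoint.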
warm_start_settings = tf.estimator.WarmStartSettings(
flags_obj.pretrained_model_checkpoint_path,
vars_to_warm_start='^(?!.*dense)')
else:
warm_start_settings = None
classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
params={
warm_start_from=warm_start_settings, params={
'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size,
'resnet_version': int(flags_obj.resnet_version),
'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj)
'dtype': flags_core.get_tf_dtype(flags_obj),
'fine_tune': flags_obj.fine_tune
})
run_params = {
......@@ -446,6 +474,15 @@ def define_resnet_flags(resnet_size_choices=None):
enum_values=['1', '2'],
help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.'))
flags.DEFINE_bool(
name='fine_tune', short_name='ft', default=False,
help=flags_core.help_wrap(
'If True, do not train any parameters except for the final layer.'))
flags.DEFINE_string(
name='pretrained_model_checkpoint_path', short_name='pmcp', default=None,
help=flags_core.help_wrap(
'If not None, initialize all of the network except the final layer with '
'these values.'))
choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
......