"gallery/transforms/plot_transforms_illustrations.py" did not exist on "8e0f6916c549005e5b18a9da1cdf94e7ac90f36c"
Commit 0a092aaf authored by Michael Carilli

Updating clip_master_grads for forward compatibility

parent 7eba6bfb
@@ -199,6 +199,13 @@ class FP16_Optimizer(object):
         self.overflow = False
         self.first_closure_call_this_step = True
 
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
+            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
+        else:
+            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_
+
     def __getstate__(self):
         raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
@@ -270,7 +277,7 @@ class FP16_Optimizer(object):
             for param_group in self.optimizer.param_groups:
                 for param in param_group['params']:
                     fp32_params.append(param)
-            return torch.nn.utils.clip_grad_norm(fp32_params, max_norm, norm_type)
+            return self.clip_grad_norm(fp32_params, max_norm, norm_type)
         else:
             return -1
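With the cached alias in place, `clip_master_grads` clips the fp32 master gradients regardless of PyTorch version. A rough usage sketch based on the method shown above; the toy model, loss scale, and learning rate are placeholder choices, and a CUDA device is assumed:

```python
# Hedged usage sketch: clip the fp32 master gradients after the scaled
# backward pass and before the optimizer step. Assumes a CUDA device.
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(16, 4).cuda().half()
optimizer = FP16_Optimizer(torch.optim.SGD(model.parameters(), lr=0.1),
                           static_loss_scale=128.0)

loss = model(torch.randn(8, 16).cuda().half()).float().sum()
optimizer.backward(loss)                      # scaled backward, fp16 -> fp32 grads
grad_norm = optimizer.clip_master_grads(1.0)  # returns -1 if the grads overflowed
optimizer.step()
```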
@@ -4,6 +4,7 @@ To use `FP16_Optimizer` on a half-precision model, or a model with a mixture of
 half and float parameters, only two lines of your training script need to change:
 1. Construct an `FP16_Optimizer` instance from an existing optimizer.
 2. Replace `loss.backward()` with `optimizer.backward(loss)`.
+[Full API Documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)
 See "Other Options" at the bottom of this page for some cases that require special treatment.
@@ -42,8 +43,8 @@ bash run.sh
 #### Other Options
 
-Gradient clipping requires that calls to `torch.nn.utils.clip_grad_norm`
-be replaced with [fp16_optimizer_instance.clip_master_grads](https://nvidia.github.io/apex/fp16_utils.html#apex.fp16_utils.FP16_Optimizer.clip_master_grads).
+Gradient clipping requires that calls to `torch.nn.utils.clip_grad_norm`
+be replaced with [fp16_optimizer_instance.clip_master_grads()](https://nvidia.github.io/apex/fp16_utils.html#apex.fp16_utils.FP16_Optimizer.clip_master_grads). The [word_language_model example](https://github.com/NVIDIA/apex/blob/master/examples/word_language_model/main_fp16_optimizer.py) uses this feature.
 
 Multiple losses will work if you simply replace
 ```bash
@@ -56,4 +57,4 @@ optimizer.backward(loss1)
 optimizer.backward(loss2)
 ```
 but `FP16_Optimizer` can be told to handle this more efficiently using the
-[update_master_grads](https://nvidia.github.io/apex/fp16_utils.html#apex.fp16_utils.FP16_Optimizer.update_master_grads) option.
+[update_master_grads()](https://nvidia.github.io/apex/fp16_utils.html#apex.fp16_utils.FP16_Optimizer.update_master_grads) option.
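Per the linked docs, `FP16_Optimizer.backward` appears to accept an `update_master_grads` flag that defers the fp16-to-fp32 gradient copy until all losses have been backpropagated. A hedged sketch of that more efficient pattern; treat the flag and the setup names as assumptions to check against the API documentation:

```python
# Hedged sketch: backpropagate both losses into the fp16 grads, then perform a
# single combined fp16 -> fp32 copy with update_master_grads().
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(16, 4).cuda().half()
optimizer = FP16_Optimizer(torch.optim.SGD(model.parameters(), lr=0.1),
                           dynamic_loss_scale=True)

optimizer.zero_grad()
loss1 = model(torch.randn(8, 16).cuda().half()).float().sum()
loss2 = model(torch.randn(8, 16).cuda().half()).float().sum()

optimizer.backward(loss1, update_master_grads=False)  # assumed keyword, see docs
optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads()  # one fp16 -> fp32 gradient copy for both losses
optimizer.step()
```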