Unverified commit 4d325d2f, authored by mcarilli, committed by GitHub

Support add_param_group (#310)

* Support add_param_group

* syntax

* Test added and passing
parent cfb628ba
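
As context for the diff below, here is a minimal usage sketch of what this patch enables. It assumes an apex build that includes this commit; the Linear models, input tensor, and hyperparameters are illustrative placeholders, not taken from the patch itself. The point is that add_param_group can now be called on an amp-wrapped optimizer after amp.initialize(), and the new parameters are folded into amp's bookkeeping (including FP32 master weights at opt_level O2).

import torch
from apex import amp

model0 = torch.nn.Linear(2, 2).cuda()
model1 = torch.nn.Linear(2, 2).cuda()

# Start with only model0's parameters in the optimizer.
optimizer = torch.optim.SGD([{'params': model0.parameters(), 'lr': 0.25}],
                            momentum=0.125)
[model0, model1], optimizer = amp.initialize([model0, model1], optimizer,
                                             opt_level="O2")

# Later, bring model1's parameters under the same optimizer.  With this patch,
# the wrapped add_param_group keeps amp's fp16/fp32 master-weight lists in sync.
optimizer.add_param_group({'params': model1.parameters(), 'lr': 0.5})

x = torch.ones(2, device='cuda')
loss = model0(x).float().sum() + model1(x).float().sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()

The test added below exercises the same pattern across opt levels O0 through O3, with and without zeroing gradients before the add, and with gradient accumulation.
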
@@ -106,7 +106,7 @@ def post_backward_with_master_weights(self, scaler):
        if fp16_param.grad is None and fp32_param.grad is not None:
            continue
        elif fp16_param.grad is not None and fp32_param.grad is None:
            fp32_param.grad = torch.empty_like(fp32_param)
            fp16_grads_needing_unscale.append(fp16_param.grad)
            new_fp32_grads.append(fp32_param.grad)
        elif fp16_param.grad is not None and fp32_param.grad is not None:
@@ -176,7 +176,7 @@ def lazy_init_no_master_weights(self):
                raise TypeError("Optimizer's parameters must be either "
                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                "Received {}".format(param.type()))

    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
@@ -328,4 +328,77 @@ def _process_optimizer(optimizer, properties):
        optimizer._post_amp_backward = types.MethodType(
            post_backward_no_master_weights, optimizer)

    old_add_param_group = optimizer.add_param_group

    def new_add_param_group(self, new_group):
        stash = self._amp_stash

        assert isinstance(new_group, dict), "param group must be a dict"

        new_params = new_group['params']
        if isinstance(new_params, torch.Tensor):
            new_group['params'] = [new_params]
        elif isinstance(new_params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            new_group['params'] = list(new_params)

        if properties.master_weights:
            # Mutate new_group in-place to use FP32 master params
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            for i, param in enumerate(new_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        fp16_params_this_group.append(param)
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        new_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        new_group['params'][i] = param
                    else:
                        raise TypeError("Optimizer's parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            stash.fp16_groups.append(fp16_params_this_group)
            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            stash.fp32_from_fp32_groups.append(fp32_params_this_group)

            stash.all_fp16_params += fp16_params_this_group
            stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
            stash.all_fp32_from_fp32_params += fp32_params_this_group

            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
            stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]

            # It should be ok to let params be added with existing .grad attributes.
            # for param in fp16_params_this_group:
            #     param.grad = None

            # for param in fp32_from_fp16_params_this_group:
            #     param.grad = None

            # for param in stash.fp32_params_this_group:
            #     param.grad = None
        else:
            for param in new_group['params']:
                if param.type() == 'torch.cuda.HalfTensor':
                    stash.all_fp16_params.append(param)
                    stash.all_fp16_grad_stash.append(None)
                elif param.type() == 'torch.cuda.FloatTensor':
                    stash.all_fp32_params.append(param)
                    stash.all_fp32_grad_stash.append(None)
                else:
                    raise TypeError("Optimizer's parameters must be either "
                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                    "Received {}".format(param.type()))

        old_add_param_group(new_group)

    optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)

    return optimizer

import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(unique +
            torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input*(weight0.float()))*(weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)


# Abandon all hope, ye who enter here.
class TestAddParamGroup(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def zero_grad(self, models, optimizer, how_to_zero):
        if how_to_zero == "none":
            for model in models:
                for param in model.parameters():
                    param.grad = None
        elif how_to_zero == "model":
            for model in models:
                model.zero_grad()
        elif how_to_zero == "optimizer":
            optimizer.zero_grad()

    def test_add_param_group(self):
        for opt_level in ("O0", "O1", "O2", "O3"):
            for zero_before_add in (True, False):
                for try_accumulation in (True, False):
                    model0 = MyModel(1)
                    model1 = MyModel(2)

                    optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                momentum=0.125)

                    optimizer.zero_grad()
                    loss = model0(self.x)
                    loss.backward()
                    optimizer.step()

                    if zero_before_add:
                        optimizer.zero_grad()
                    optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
                    if not zero_before_add:
                        optimizer.zero_grad()

                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()
                    # Once more to make sure the new params pick up momentums properly
                    optimizer.zero_grad()
                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()

                    reference_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]

                    for how_to_zero in "none", "model", "optimizer":
                        model0 = MyModel(1)
                        model1 = MyModel(2)

                        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                    momentum=0.125)

                        _amp_state.allow_incoming_model_not_fp32 = True
                        [model0, model1], optimizer = amp.initialize([model0, model1],
                                                                     optimizer,
                                                                     opt_level=opt_level,
                                                                     verbosity=0,
                                                                     cast_model_type=False)
                        _amp_state.allow_incoming_model_not_fp32 = False

                        _amp_state.loss_scalers[0]._loss_scale = 4.0

                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        optimizer.step()

                        if zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)
                        optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
                        if not zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)

                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        # Once more to make sure the new params pick up momentums properly
                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        final_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]

                        for reference, final in zip(reference_params, final_params):
                            self.assertTrue(torch.allclose(reference.to(final.dtype), final),
                                            "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
                                            opt_level, how_to_zero, zero_before_add))


if __name__ == '__main__':
    unittest.main()