Unverified Commit 2ec84ebd authored by Vinicius Reis, committed by GitHub

Fix LARC with mixed precision (#793)

The LARC optimizer wraps an underlying optimizer and then needs to be passed
to amp.initialize for mixed precision. Three different crashes could occur in
this situation; this change fixes all of them and adds a unit test.
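
For reference, the call pattern that used to crash is roughly the following. This is a minimal sketch assuming apex is installed and a CUDA device is available; the model and hyperparameters are illustrative placeholders that mirror the new unit test below.

import torch
from apex import amp
from apex.parallel.LARC import LARC

model = torch.nn.Linear(2, 2).cuda()
# LARC wraps a regular optimizer rather than subclassing torch.optim.Optimizer.
optimizer = LARC(torch.optim.SGD(model.parameters(), lr=0.25))

# Passing the wrapper to amp.initialize is the step that previously crashed.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

optimizer.zero_grad()
loss = model(torch.ones(2, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()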

I don't know whether the 'LARC' in sys.modules check ever worked; in my setup
the entry in sys.modules is 'apex.parallel.LARC'. Checking whether the LARC
name is defined in globals() seems more reliable.
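
To illustrate the difference between the two guards, here is a small standalone sketch. The standard-library module names are used purely as an analogy, and the apex import is guarded and optional.

import sys

# sys.modules is keyed by fully qualified module names, so importing a
# submodule never creates a bare top-level entry -- analogous to how
# 'LARC' never shows up even when apex.parallel.LARC has been imported.
import xml.dom.minidom
print('minidom' in sys.modules)          # False
print('xml.dom.minidom' in sys.modules)  # True

# The new guard asks whether the name LARC is bound in the current module,
# which is exactly what a guarded "from ... import LARC" produces.
try:
    from apex.parallel.LARC import LARC
except ImportError:
    pass

print('LARC' in globals())  # True only if the import above succeeded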
parent 55716d85
@@ -146,7 +146,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     from .amp import init as amp_init
 
     optimizers_was_list = False
-    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
     elif optimizers is None:
         optimizers = []

@@ -87,7 +87,7 @@ def scale_loss(loss,
         yield loss
         return
 
-    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
 
     loss_scaler = _amp_state.loss_scalers[loss_id]

@@ -48,6 +48,10 @@ class LARC(object):
     def __setstate__(self, state):
         self.optim.__setstate__(state)
 
+    @property
+    def state(self):
+        return self.optim.state
+
     def __repr__(self):
         return self.optim.__repr__()

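The new state property presumably addresses the third crash: LARC is a plain wrapper object rather than a torch.optim.Optimizer subclass, so attributes that amp reads directly on the optimizer it is given, such as .state, must be forwarded to the wrapped optimizer explicitly. A minimal sketch of that delegation pattern, with the Wrapper name chosen purely for illustration:

import torch

class Wrapper(object):
    # Illustrative stand-in for LARC: holds an optimizer instead of subclassing it.
    def __init__(self, optim):
        self.optim = optim

    @property
    def state(self):
        # Forward reads so callers that expect a bare optimizer (e.g. amp)
        # see the inner optimizer's per-parameter state dict.
        return self.optim.state

param = torch.nn.Parameter(torch.zeros(2))
wrapped = Wrapper(torch.optim.SGD([param], lr=0.1))
print(wrapped.state is wrapped.optim.state)  # True

The new unit test follows; it exercises a LARC-wrapped SGD under each amp opt_level (O0 through O3).
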
import unittest

import torch
from torch import nn
from torch.nn import Parameter

from apex import amp
from apex.parallel.LARC import LARC
from utils import common_init


class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(
            unique + torch.arange(2, device="cuda", dtype=torch.float32)
        )

    def forward(self, input):
        return (input * self.weight0).sum()


class TestLARC(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device="cuda", dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def test_larc_mixed_precision(self):
        for opt_level in ["O0", "O1", "O2", "O3"]:
            model = MyModel(1)

            optimizer = LARC(
                torch.optim.SGD(
                    [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                )
            )

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0
            )

            optimizer.zero_grad()
            loss = model(self.x)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()


if __name__ == "__main__":
    unittest.main()