"src/diffusers/commands/__init__.py" did not exist on "cbb19ee84ee495e819e4ef5723cc3412b237ff1d"
Commit 887a50bd authored by Michael Carilli

Better way to expose scale adjustment

parent 9efb2809
@@ -122,7 +122,7 @@ def scale_loss(loss,
     # necessary because amp.scale_loss is already creating a temporary scope.
     def patch_step(opt, loss_scaler, loss_id):
         opt_step = opt.step
-        def skip_step(scale=None):
+        def skip_step():
             maybe_print(("Gradient overflow. Skipping step, loss scaler " +
                          "{} reducing loss scale to {}").format(loss_id,
                          loss_scaler.loss_scale()))
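The hunk above drops the unused scale argument from the patched skip_step: the adjustment is now carried by the optimizer itself (see the FusedAdam hunks below) instead of being threaded through the step patch. A minimal, self-contained sketch of the step-patching pattern follows; ToyOptimizer and the printed messages are illustrative only, not the apex implementation.

# Sketch of the "patch opt.step to skip one update" pattern used by amp
# when a gradient overflow is detected. Not the actual apex code.
class ToyOptimizer:
    def step(self):
        print("real step: applying parameter update")

def patch_step(opt):
    real_step = opt.step

    def skip_step():  # note: no `scale` argument, matching the change above
        print("Gradient overflow. Skipping step.")
        opt.step = real_step  # restore the real step for the next iteration

    opt.step = skip_step

opt = ToyOptimizer()
patch_step(opt)
opt.step()   # prints the skip message and restores the original step
opt.step()   # performs the real parameter update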
@@ -39,7 +39,8 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction = True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False):
+                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
+                 amp_scale_adjustment=1.0):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

@@ -51,6 +52,8 @@ class FusedAdam(torch.optim.Optimizer):
             self._use_multi_tensor = True
             self._overflow_buf = torch.cuda.IntTensor([0])

+        self._amp_scale_adjustment = amp_scale_adjustment
+
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,

@@ -81,7 +84,7 @@ class FusedAdam(torch.optim.Optimizer):
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
-            scale = self._amp_stash.scale*scale
+            scale = self._amp_stash.scale*self._amp_scale_adjustment
             grad_norms = self._amp_stash.grad_norms
         if grads is None:
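Taken together, the FusedAdam hunks add an amp_scale_adjustment constructor argument (default 1.0) and multiply it into the amp loss scale inside step(), rather than passing a scale through the patched step. A hedged usage sketch: the amp.initialize / amp.scale_loss calls follow the usual apex pattern and may differ across apex versions, and the model and data here are placeholders.

import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()
# amp_scale_adjustment is folded into the amp loss scale inside
# FusedAdam.step(); the default of 1.0 leaves the scale unchanged.
optimizer = FusedAdam(model.parameters(), lr=1e-3, amp_scale_adjustment=1.0)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()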