Commit 4b913261 authored by Aron Hoffmann's avatar Aron Hoffmann Committed by mcarilli
Browse files

Made the patched optimizer step function a full method, not simply a function stored as an instance member (#553)
parent 08898593
......@@ -3,6 +3,7 @@ from torch._six import string_classes
import functools
import numpy as np
import sys
from types import MethodType
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts
......@@ -236,13 +237,13 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
def patch_step(old_step):
    """Return a wrapper around ``old_step`` that runs with Amp casting disabled.

    The wrapper accepts ``self`` so it can be bound to the optimizer via
    ``types.MethodType`` below — making the patched step a genuine method
    instead of a plain function stashed as an instance attribute (#553).
    """
    def new_step(self, *args, **kwargs):
        # ``old_step`` is the already-bound original ``optimizer.step``,
        # so ``self`` is deliberately not forwarded to it.
        with disable_casts():
            output = old_step(*args, **kwargs)
        return output
    return new_step

# Bind the wrapper to the optimizer instance so ``optimizer.step`` is a
# real method (with ``__self__``/``__func__``), not a bare function.
optimizer.step = MethodType(patch_step(optimizer.step), optimizer)
if optimizers_was_list:
if models_was_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment