Commit 889d1712 authored by Michael Carilli

New API tentatively works on resnet50, ready for stress testing.

parent fad78c16
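
For orientation before the diff: a minimal usage sketch of the new frontend exercised on resnet50 in this commit. Names follow frontend.register and handle.scale_loss below; the API is explicitly tentative, so treat this as a sketch rather than a stable interface.

    import torch
    import apex.amp as amp

    model = torch.nn.Linear(128, 128).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # register() applies the chosen opt_level's defaults, then any not-None kwarg
    # overrides, and hands everything to _initialize(), which returns the (possibly
    # cast/wrapped) model and optimizer in the same single-vs-list shape it received.
    model, optimizer = amp.register(model, optimizer, opt_level="O1")

    x = torch.randn(32, 128, device="cuda")
    loss = model(x).sum()

    # scale_loss yields loss * loss_scale; on exit it unscales the grads and, if an
    # overflow is detected with dynamic scaling, patches optimizer.step into a no-op
    # for this iteration.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
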
from . import parallel
from . import amp
from . import fp16_utils

# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
......
from .amp import init, half_function, float_function, promote_function,\
    register_half_function, register_float_function, register_promote_function
from .handle import scale_loss
from .frontend import register
from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier
# This is a "header object" that allows different amp modules to communicate.
# I'm a C++ guy, not a Python guy. I chose this approach because it seemed the most C++-like,
# but apparently it's an accepted pattern:
# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
class AmpState(object):
    pass

# Attribute stash. Could also just stash things as global module attributes.
_amp_state = AmpState()
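
As a tiny illustration of the shared-module-state pattern described in the comments above (the consumer module below is hypothetical), any amp submodule can import the single _amp_state instance and read or write attributes on it:

    # hypothetical consumer module inside apex.amp
    from ._amp_state import _amp_state

    def remember_handle(handle):
        # Every importer sees the same AmpState instance, so anything stashed here
        # (the active handle, opt_properties, ...) is visible to all amp modules.
        _amp_state.handle = handle

    def current_handle():
        return getattr(_amp_state, "handle", None)
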
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ._amp_state import _amp_state
from .frontend import *
import functools
@@ -170,4 +171,7 @@ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False,
        wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)

    _DECORATOR_HANDLE = handle
    _amp_state.handle = handle

    return handle
import torch
from .initialize import _initialize
from ._amp_state import _amp_state


class Properties(object):
@@ -10,6 +11,7 @@ class Properties(object):
    """
    def __init__(self):
        self.options = {
            "enabled" : False,
            "opt_level" : None,
            "cast_model_type" : None,
            "cast_torch_functions" : False,
@@ -18,6 +20,7 @@ class Properties(object):
            "loss_scale" : 1.0,
            "flatten_model_params" : False,
            "flatten_master_params" : False,
            "fused_optimizer" : False,
            "enable_ddp_interop" : False}

    """
@@ -45,7 +48,7 @@ class Properties(object):
    def __setattr__(self, name, value):
        if "options" in self.__dict__:
            if name in self.options:
                print("setting {} {}".format(name, value))
                self.options[name] = value
        else:
            super(Properties, self).__setattr__(name, value)
@@ -63,7 +66,8 @@ class O3:
        "If not, try other optimization levels."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O3"
        properties.cast_model_type = torch.float16
        properties.cast_torch_functions = False
        properties.cast_batchnorm = False
@@ -71,6 +75,7 @@ class O3:
        properties.loss_scale = 1.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary
@@ -86,7 +91,8 @@ class O2:
        "Master weights can also improve convergence and stability."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
        properties.cast_torch_functions = False
        properties.cast_batchnorm = torch.float32
@@ -94,6 +100,7 @@ class O2:
        properties.loss_scale = 128.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary
@@ -108,7 +115,8 @@ class O1:
        "trying mixed precision training for the first time."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O1"
        properties.cast_model_type = False
        properties.cast_torch_functions = True
        properties.cast_batchnorm = False
@@ -116,6 +124,7 @@ class O1:
        properties.loss_scale = "dynamic"
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary
@@ -128,7 +137,8 @@ class O0:
        "may still be requested.\n"

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O0"
        properties.cast_model_type = torch.float32
        properties.cast_torch_functions = False
        properties.cast_batchnorm = False
@@ -136,6 +146,7 @@ class O0:
        properties.loss_scale = 1.0
        properties.flatten_model_params = False
        properties.flatten_master_params = False
        properties.fused_optimizer = False
        properties.enable_ddp_interop = False
        return properties  # modified in place so this isn't really necessary
@@ -162,47 +173,49 @@ def check_params_fp32(model):

# allow user to directly pass Properties struct as well?
def register(models, optimizers, enabled=True, opt_level=None, **kwargs):
    """
    Expected kwargs:
        opt_level=None,
        cast_model_type=None,
        cast_torch_functions=None,
        cast_batchnorm=None,
        master_weights=None,
        loss_scale=None,
        flatten_model_params=None,
        flatten_master_params=None,
        enable_ddp_interop=None
    """
    if not enabled:
        return models, optimizers

    if opt_level not in opt_levels:
        raise RuntimeError(
            "Unexpected optimization level {}. ".format(opt_level) +
            "Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        _amp_state.opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}".format(opt_levels[opt_level].brief))
        print("Defaults for this optimization level are:")
        print(_amp_state.opt_properties.options)
        for k, v in _amp_state.opt_properties.options.items():
            print("{:20} : {}".format(k, v))

    for model in models:
        check_params_fp32(model)

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs.items():
        if k not in _amp_state.opt_properties.options:
            raise RuntimeError("Unexpected kwarg {}".format(k))
        if v is not None:
            setattr(_amp_state.opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in _amp_state.opt_properties.options.items():
        print("{:20} : {}".format(k, v))

    return _initialize(models, optimizers, _amp_state.opt_properties)
def check_option_consistency(enabled=True,
                             opt_level=None,
                             cast_model_type=None,
                             cast_torch_functions=None,
@@ -230,13 +243,15 @@ def check_option_consistency(enabled=False,
    print("Selected optimization level {}".format(opt_levels[opt_level].brief))
    print("Defaults for this optimization level are:")
    for k, v in opt_properties.options.items():
        print("{:20} : {}".format(k, v))

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs.items():
        if k not in opt_properties.options:
            raise RuntimeError("Unexpected kwarg {}".format(k))
        if v is not None:
            setattr(opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in opt_properties.options.items():
        print("{:20} : {}".format(k, v))
@@ -5,6 +5,64 @@ import warnings
from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler, iter_params
from ._amp_state import _amp_state
from ..fp16_utils import FP16_Optimizer


# There's no reason to expose the notion of a "handle".  Everything can happen through amp.* calls.
@contextlib.contextmanager
def scale_loss(loss,
               optimizer,
               model=None,
               delay_unscale=False):
    if not _amp_state.opt_properties.enabled:
        yield loss
        return

    if optimizer.loss_scaler is None:
        raise RuntimeError("optimizer passed to scale_loss does not have a loss_scaler.")

    loss_scale = optimizer.loss_scaler.loss_scale()

    if ((not _amp_state.opt_properties.master_weights)
            and (not optimizer.loss_scaler.dynamic)
            and loss_scale == 1.0):
        yield loss
        # Needing to drop the cache here as well is an ugly gotcha.
        # But for now I think it's necessary to short-circuit.
        # Probably ok to skip this if not delay_unscale
        if _amp_state.opt_properties.cast_torch_functions:
            _amp_state.handle._clear_cache()
        return

    yield loss*loss_scale

    # This isn't pretty, but it unifies things.  Once I deprecate the old API entirely,
    # I will have freedom to clean this up.  Maybe instead of wrapping optimizers,
    # I can simply construct a set of attributes (e.g. master params) and assign them
    # directly to optimizer instances.
    if not delay_unscale:
        if isinstance(optimizer, FP16_Optimizer):
            optimizer.update_master_grads()
        else:
            optimizer.loss_scaler.unscale(
                iter_params(optimizer.param_groups),
                iter_params(optimizer.param_groups),
                loss_scale)
            # If overflow_check_on_cpu is False, should_skip will always be False.
            should_skip = optimizer.loss_scaler.update_scale()
            if should_skip:
                optimizer_step = optimizer.step
                def skip_step():
                    logger = logging.getLogger('apex.amp')
                    logger.warning('Gradient overflow, skipping update')
                    optimizer.step = optimizer_step
                optimizer.step = skip_step

    # Probably ok to skip this if not delay_unscale
    if _amp_state.opt_properties.cast_torch_functions:
        _amp_state.handle._clear_cache()
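
One intended use of the delay_unscale flag above appears to be gradient accumulation, where unscaling and the loss-scale update are deferred until the last backward pass of the window. The loop below is only an illustration (it assumes amp.register has already been called so optimizer.loss_scaler and _amp_state.opt_properties exist):

    import torch
    from apex.amp import scale_loss

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [torch.randn(8, 16, device="cuda") for _ in range(8)]
    accumulation_steps = 4

    for i, x in enumerate(batches):
        loss = model(x).sum() / accumulation_steps
        last = (i + 1) % accumulation_steps == 0
        # delay_unscale=True leaves the scaled grads accumulating; only the final
        # backward of the window unscales them and updates the dynamic loss scale.
        with scale_loss(loss, optimizer, delay_unscale=not last) as scaled_loss:
            scaled_loss.backward()
        if last:
            optimizer.step()
            optimizer.zero_grad()
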
class AmpHandle(object):
    def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
@@ -43,10 +101,11 @@ class AmpHandle(object):
        loss_scale = self._default_scaler.loss_scale()

        yield loss * loss_scale

        self._default_scaler.unscale(
            iter_params(optimizer.param_groups),
            iter_params(optimizer.param_groups),
            loss_scale)
        should_skip = self._default_scaler.update_scale()
        if should_skip:
            optimizer_step = optimizer.step
            def skip_step():
@@ -108,5 +167,8 @@ class NoOpHandle(object):
    def verbose(self):
        return False

    def _clear_cache(self):
        pass

    def _deactivate(self):
        pass
@@ -2,6 +2,25 @@ import torch
from torch._six import container_abcs, string_classes
import functools

from apex.fp16_utils import convert_network
from ._amp_state import _amp_state
from .scaler import LossScaler
from ..fp16_utils import FP16_Optimizer


def check_params_fp32(model):
    for name, param in model.named_parameters():
        if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, param.type()))

    for name, buf in model.named_buffers():
        if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
            print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.".format(
                  name, buf.type()))


def to_type(dtype, t):
@@ -11,7 +30,7 @@ def to_type(dtype, t):
        print("Warning: input data requires grad.  Since input data is not a model parameter,\n"
              "its gradients will not be properly allreduced by DDP.")
    if t.is_floating_point():
        return t.to(dtype)
    return t
@@ -29,13 +48,47 @@ def applier(value, fn):
        return value


def _initialize(models, optimizers, properties):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init

    if isinstance(optimizers, torch.optim.Optimizer):
        optimizers_was_list = False
        optimizers = [optimizers]
    elif isinstance(optimizers, list):
        optimizers_was_list = True
    else:
        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")

    if isinstance(models, torch.nn.Module):
        models_was_list = False
        models = [models]
    elif isinstance(models, list):
        models_was_list = True
    else:
        raise TypeError("models must be either a single model or a list of models.")

    for model in models:
        parallel_type = None
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            parallel_type = "torch.nn.parallel.DistributedDataParallel"
        if isinstance(model, apex_DDP):
            parallel_type = "apex.parallel.DistributedDataParallel"
        if isinstance(model, torch.nn.parallel.DataParallel):
            parallel_type = "torch.nn.parallel.DataParallel"
        if parallel_type is not None:
            raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
                               "Parallel wrappers should only be applied AFTER the model(s) have been "
                               "returned from amp.register.")

    for model in models:
        check_params_fp32(model)

    # Stash master weights before casting the model.
    # if properties.master_weights:

    if properties.cast_model_type:
        if properties.cast_batchnorm:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
@@ -50,7 +103,7 @@ def initialize(optimizers, models, properties):
                    return old_fwd(*applier(args, caster),
                                   **applier(kwargs, caster))
                return new_fwd
            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
@@ -60,11 +113,23 @@ def initialize(optimizers, models, properties):
    if properties.master_weights:
        for i, optimizer in enumerate(optimizers):
            if properties.loss_scale == "dynamic":
                optimizers[i] = FP16_Optimizer(optimizers[i], dynamic_loss_scale=True)
            else:
                optimizers[i] = FP16_Optimizer(optimizers[i], static_loss_scale=properties.loss_scale)
    else:
        for optimizer in optimizers:
            optimizer.loss_scaler = LossScaler(properties.loss_scale)

    if properties.cast_torch_functions:
        handle = amp_init(loss_scale=properties.loss_scale)

    if optimizers_was_list:
        if models_was_list:
            return models, optimizers
        else:
            return models[0], optimizers
    else:
        if models_was_list:
            return models, optimizers[0]
        else:
            return models[0], optimizers[0]
import torch

class MultiTensorApply(object):
    available = False
    warned = False

    def __init__(self, max_blocks, max_tensors, max_depth, chunk_size):
        try:
            import amp_C
            MultiTensorApply.available = True
            MultiTensorApply.prep_multi_tensor_launch = amp_C.prep_multi_tensor_launch
            self.chunk_size = chunk_size
            self.reallocate(max_blocks, max_tensors, max_depth)
        except ImportError as err:
            MultiTensorApply.available = False
            MultiTensorApply.import_err = err

    def check_avail(self):
        if MultiTensorApply.available == False:
            raise RuntimeError(
                "Attempted to call MultiTensorApply method, but MultiTensorApply "
                "is not available, possibly because Apex was installed without "
                "--cpp_ext --cuda_ext.  Original import error message:",
                MultiTensorApply.import_err)

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        self.check_avail()

        assert len(tensor_lists) > 0, "len(tensor_lists) = {}".format(len(tensor_lists))
        len0 = len(tensor_lists[0])
        assert len0 > 0, "len(tensor_lists[0]) = {}".format(len0)
        for i, l in enumerate(tensor_lists):
            assert len(tensor_lists[i]) == len0,\
                "len(tensor_lists[{}]) = {}, len(tensor_lists[0]) = {}".format(
                i, len(tensor_lists[i]), len0)

        self.assign_blocks(tensor_lists)
        # print(self.gpu_block_to_tensor)
        # print(self.gpu_block_to_chunk)
        # print(self.gpu_tensor_sizes)
@@ -16,11 +42,11 @@ class MultiTensorApply(object):
        return op(self.nblocks,
                  noop_flag_buffer,
                  self.cpu_tensor_addresses,
                  self.gpu_block_to_tensor,
                  self.gpu_block_to_chunk,
                  self.gpu_tensor_sizes,
                  self.gpu_tensor_addresses,
                  self.chunk_size,
                  tensor_lists,
                  *args)
@@ -30,6 +56,8 @@ class MultiTensorApply(object):
        # print(self.gpu_tensor_addresses)

    def assign_blocks(self, tensor_lists):
        self.check_avail()

        needs_reallocate = False
        # Currently, this loop appears prohibitively expensive.
@@ -38,7 +66,7 @@ class MultiTensorApply(object):
        # list0 = tensor_lists[0]
        # self.nblocks = 0
        # for t, tensor in enumerate(list0):
        #     blocks_this_tensor = (tensor.numel() +
        #                           self.chunk_size - 1)//self.chunk_size
        #     if not needs_reallocate:
        #         self.cpu_tensor_sizes[t] = tensor.numel()
@@ -49,20 +77,21 @@ class MultiTensorApply(object):
        #         self.cpu_block_to_tensor[self.nblocks] = t
        #         self.cpu_block_to_chunk[self.nblocks] = chunk
        #         self.nblocks += 1

        needs_reallocate, self.nblocks = MultiTensorApply.prep_multi_tensor_launch(
            self.cpu_block_to_tensor,
            self.cpu_block_to_chunk,
            self.cpu_tensor_sizes,
            self.gpu_block_to_tensor,
            self.gpu_block_to_chunk,
            self.gpu_tensor_sizes,
            self.chunk_size,
            self.max_depth,
            self.max_tensors,
            self.max_blocks,
            tensor_lists)

        torch.cuda.nvtx.range_pop()
        # print(self.nblocks)

        if self.nblocks > self.max_blocks:
            self.max_blocks = self.nblocks
@@ -73,23 +102,26 @@ class MultiTensorApply(object):
        if needs_reallocate:
            self.reallocate(self.max_blocks, self.max_tensors, self.max_depth)
            needs_reallocate, self.nblocks = MultiTensorApply.prep_multi_tensor_launch(
                self.cpu_block_to_tensor,
                self.cpu_block_to_chunk,
                self.cpu_tensor_sizes,
                self.gpu_block_to_tensor,
                self.gpu_block_to_chunk,
                self.gpu_tensor_sizes,
                self.chunk_size,
                self.max_depth,
                self.max_tensors,
                self.max_blocks,
                tensor_lists)
            assert needs_reallocate == 0, "Should not need reallocate on second attempt."
            assert self.nblocks <= self.max_blocks, "Should not need to increase blocks again."

    def reallocate(self, max_blocks, max_tensors, max_depth):
        self.check_avail()

        self.max_blocks = max_blocks
        self.max_tensors = max_tensors
        self.max_depth = max_depth
        self.cpu_block_to_tensor = torch.IntTensor(max_blocks).pin_memory()
@@ -101,3 +133,5 @@ class MultiTensorApply(object):
        self.gpu_block_to_chunk = torch.cuda.IntTensor(max_blocks)
        self.gpu_tensor_sizes = torch.cuda.IntTensor(max_tensors)
        self.gpu_tensor_addresses = torch.cuda.LongTensor(max_depth, max_tensors)

multi_tensor_applier = MultiTensorApply(1000, 100, 4, 2048)
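
The module-level multi_tensor_applier instance above is what LossScaler uses to launch the fused unscale kernel; the call pattern is (op, overflow/no-op flag buffer, list of equal-length tensor lists, extra scalar args). A sketch of that call, mirroring scaler.py below; it assumes apex was built with --cpp_ext --cuda_ext so amp_C and amp_C.multi_tensor_unscale exist:

    import torch
    from apex.amp import multi_tensor_applier

    if multi_tensor_applier.available:
        import amp_C
        overflow_buf = torch.cuda.IntTensor([0])
        model_grads = [torch.randn(1024, device="cuda", dtype=torch.float16)]
        master_grads = [torch.zeros(1024, device="cuda", dtype=torch.float32)]
        # Writes unscaled model grads into the master grads and sets overflow_buf
        # if any value is inf/nan; the trailing arg is 1/loss_scale, exactly as in
        # LossScaler.unscale.
        multi_tensor_applier(amp_C.multi_tensor_unscale,
                             overflow_buf,
                             [model_grads, master_grads],
                             1./128.0)
        print("overflow detected:", overflow_buf.item() != 0)
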
@@ -37,10 +37,11 @@ class OptimWrapper(object):
        loss_scale = self._cur_loss_scaler().loss_scale()

        yield loss * loss_scale

        self._cur_loss_scaler().unscale(
            iter_params(self._optimizer.param_groups),
            iter_params(self._optimizer.param_groups),
            loss_scale)
        self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
        self._loss_idx += 1

        if len(cached_grads) > 0:
......
import torch
import logging

from .multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state

# from apex_C import scale_check_overflow

def scale_check_overflow_python(model_grad, scale, master_grad):
    # Exception handling for 18.04 compatibility
    try:
        cpu_sum = float(model_grad.float().sum())
    except RuntimeError as instance:
        if "value cannot be converted" not in instance.args[0]:
            raise
@@ -16,9 +18,10 @@ def scale_check_overflow_python(model_grad, scale, master_grad):
            return True
    if master_grad is not model_grad:
        master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
    return False

class LossScaler(object):
    warned_no_fused_kernel = False
    warned_fp16_grad = False
@@ -39,48 +42,88 @@ class LossScaler(object):
        self._scale_seq_len = scale_window
        self._unskipped = 0
        self._has_overflow = False

        self._overflow_buf = torch.cuda.IntTensor([0])
        if multi_tensor_applier.available:
            import amp_C
            LossScaler.has_fused_kernel = multi_tensor_applier.available
            LossScaler.multi_tensor_unscale_cuda = amp_C.multi_tensor_unscale
        else:
            if not LossScaler.warned_no_fused_kernel:
                print("Warning: multi_tensor_applier fused downscale kernel is unavailable, "
                      "possibly because apex was installed without --cuda_ext --cpp_ext. "
                      "Using Python fallback. Original ImportError was: ",
                      multi_tensor_applier.import_err)
            LossScaler.has_fused_kernel = False
            LossScaler.warned_no_fused_kernel = True

    def loss_scale(self):
        return self._loss_scale

    def unscale_grads_python(self, model_grads, master_grads, scale):
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if (master.type() != "torch.cuda.FloatTensor"
                        and not LossScaler.warned_fp16_grad):
                    logger = logging.getLogger("apex.amp")
                    logger.warning(
                        "Attempting to downscale {} grads. ".format(master.type()) +
                        "Downscaling non-fp32 grads may indicate an error. "
                        "When using Amp, you don't need to call .half() on your model.")
                    LossScaler.warned_fp16_grad = True
                self._has_overflow = scale_check_overflow_python(
                    model,
                    1./scale,
                    master)
                if self._has_overflow and self.dynamic:
                    break

    def unscale(self, model_params, master_params, scale):
        self._has_overflow = False

        # Lots of defensive list processing going on here.  Way less efficient than
        # consuming the iterator directly.  Need to examine Python overhead.
        model_master_params = [(model, master) for model, master
                               in zip(model_params, master_params)]  # some of these may be None

        # Sync the None-ness of model and master params.
        all_same = True
        for model, master in model_master_params:
            if model.grad is None and master.grad is not None:
                master.grad = None
            if model.grad is not None and master.grad is None:
                master.grad = torch.empty_like(master)
            if model.grad is not master.grad:
                all_same = False

        model_grads = [mmp[0].grad.data for mmp in model_master_params if mmp[0].grad is not None]
        master_grads = [mmp[1].grad.data for mmp in model_master_params if mmp[1].grad is not None]

        if LossScaler.has_fused_kernel:
            # The master grads should never be fp16.  The kernel can't handle that, so bail out
            # and print a warning.  This is overly conservative, and maybe we do want to enable
            # fast downscaling of fp16 grads eventually.
            if any(grad.type() == "torch.cuda.HalfTensor" for grad in master_grads):
                self.unscale_grads_python(model_grads, master_grads, scale)
            else:
                # This is inefficient if opt_level is O1 and loss scale is 1.0.  But to elide
                # the launch, I would need to make sure the model grads are the master grads.
                # The O(N) checks are proliferating...
                self._overflow_buf.zero_()
                # Handle the case of opt_level O1 and loss_scale 1.0.  There are also some
                # special-cased yields in scale_loss to potentially short-circuit earlier.
                if scale == 1.0 and all_same and not self.dynamic:
                    return
                else:
                    multi_tensor_applier(
                        LossScaler.multi_tensor_unscale_cuda,
                        self._overflow_buf,
                        [model_grads, master_grads],
                        1./scale)
        else:
            self.unscale_grads_python(model_grads, master_grads, scale)

    # Break into multiple param groups so unscale() can be called more than once before updating.
    def update_scale(self):
        # If the fused kernel is available, we only need one D2H memcopy and sync.
        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
            self._has_overflow = self._overflow_buf.item()
......
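
Splitting the old unscale_and_update into unscale() plus update_scale() lets callers unscale several parameter sets before deciding once whether to skip the step. A sketch of the pattern used by the callers in this commit; the LossScaler("dynamic") constructor argument mirrors the LossScaler(properties.loss_scale) call in _initialize and is an assumption about the elided __init__:

    import torch
    from apex.amp.scaler import LossScaler

    params = [torch.nn.Parameter(torch.randn(10, device="cuda")) for _ in range(3)]
    opt = torch.optim.SGD(params, lr=0.1)
    scaler = LossScaler("dynamic")

    loss = sum((p*p).sum() for p in params)
    (loss * scaler.loss_scale()).backward()

    # unscale() may be called more than once (e.g. per param group)...
    scaler.unscale(params, params, scaler.loss_scale())
    # ...and update_scale() is called once at the end; it returns True if an
    # overflow was seen, in which case the step is skipped.
    should_skip = scaler.update_scale()
    if not should_skip:
        opt.step()
    opt.zero_grad()
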
@@ -14,3 +14,5 @@ from .fp16util import (
from .fp16_optimizer import FP16_Optimizer
from .loss_scaler import LossScaler, DynamicLossScaler

test = 1
@@ -39,7 +39,7 @@ class FP16_Optimizer(object):
        init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
        dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`LossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`LossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`LossScaler`'s defaults will be used.
        verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
@@ -154,6 +154,18 @@
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        self.all_fp16_params = []
        for group in self.fp16_groups:
            self.all_fp16_params += group

        self.all_fp32_from_fp16_params = []
        for group in self.fp32_from_fp16_groups:
            self.all_fp32_from_fp16_params += group

        self.all_fp32_from_fp32_params = []
        for group in self.fp32_from_fp32_groups:
            self.all_fp32_from_fp32_params += group

        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # alternative way to cast per-param state tensors:
@@ -210,35 +222,36 @@
                    param.grad.detach_()  # as in torch.optim.optimizer.zero_grad()
                    param.grad.zero_()

    # Should not be used anymore.
    # def _check_overflow(self):
    #     params = []
    #     for group in self.fp16_groups:
    #         for param in group:
    #             params.append(param)
    #     for group in self.fp32_from_fp32_groups:
    #         for param in group:
    #             params.append(param)
    #     self.overflow = self.loss_scaler.has_overflow(params)

    # def _update_scale(self, has_overflow=False):
    #     self.loss_scaler.update_scale(has_overflow)

    def _master_params_to_model_params(self):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            master_params_to_model_params(fp16_group, fp32_from_fp16_group)

    # To consider:  Integrate distributed with this wrapper by registering a hook on each variable
    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
    # def _model_grads_to_master_grads(self):
    #     for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
    #         model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

    # def _downscale_master(self):
    #     if self.loss_scale != 1.0:
    #         for group in self.optimizer.param_groups:
    #             for param in group['params']:
    #                 if param.grad is not None:
    #                     param.grad.data.mul_(1./self.loss_scale)

    def clip_master_grads(self, max_norm, norm_type=2):
        """
@@ -366,12 +379,15 @@
        http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """

        scale = self.loss_scaler.loss_scale()
        # To consider:  Should this be in step(), or update_master_grads?  It works either way,
        # but I should make it consistent with the Amp control flow, which updates the scale
        # during backward context manager exit.
        # self._update_scale(self.overflow)
        if self.overflow:
            print("OVERFLOW! Skipping step, reducing loss scale to {}".format(
                self.loss_scaler.loss_scale()))
            return

        if closure is not None:
@@ -409,10 +425,10 @@
            # closure() and return the loss.
            temp_loss = closure()
            while(self.overflow):
                scale = self.loss_scaler.loss_scale()
                # self._update_scale(self.overflow)  # now done at the end of backward
                print("OVERFLOW within closure! Skipping step, reducing loss scale to {}".format(
                    self.loss_scaler.loss_scale()))
                temp_loss = closure()
            return temp_loss
@@ -480,7 +496,8 @@
        # a loss scale that works.  After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph.  Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        scaled_loss = loss.float()*self.loss_scaler.loss_scale()
        scaled_loss.backward(retain_graph=retain_graph)
        if update_master_grads:
            self.update_master_grads()
@@ -491,11 +508,24 @@
        updated by the optimizer.  :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        # if self.dynamic_loss_scale:
        #     self._check_overflow()
        #     if self.overflow: return
        # self._model_grads_to_master_grads()
        # self._downscale_master()

        # Use the one-shot multi-tensor apply kernel
        if len(self.all_fp16_params) > 0:
            self.loss_scaler.unscale(
                self.all_fp16_params,
                self.all_fp32_from_fp16_params,
                self.loss_scaler.loss_scale())
        if len(self.all_fp32_from_fp32_params) > 0:
            self.loss_scaler.unscale(
                self.all_fp32_from_fp32_params,
                self.all_fp32_from_fp32_params,
                self.loss_scaler.loss_scale())
        self.overflow = self.loss_scaler.update_scale()

    def inspect_master_grad_data(self):
        """
@@ -533,10 +563,10 @@
    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale()

    def _set_loss_scale(self, value):
        self.loss_scaler._loss_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
......
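
For the master-weights path, the reworked update_master_grads above now performs a one-shot unscale over the flattened param lists and sets self.overflow via loss_scaler.update_scale(). Typical FP16_Optimizer usage is roughly as follows (a sketch based on the docstrings above, not a verbatim example from this commit):

    import torch
    from apex.fp16_utils import FP16_Optimizer

    model = torch.nn.Linear(128, 128).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    x = torch.randn(32, 128, device="cuda", dtype=torch.float16)
    loss = model(x).sum()

    optimizer.backward(loss)   # scales the loss, backprops, then calls update_master_grads()
    optimizer.step()           # skipped internally (with a printout) if overflow was flagged
    optimizer.zero_grad()
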
@@ -52,7 +52,7 @@ struct UnscaleFunctor
      {
        incoming_vals[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          incoming_vals[ii] = static_cast<float>(in[i]);
      }
@@ -60,7 +60,7 @@ struct UnscaleFunctor
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          if(isfinite(incoming_vals[ii]))
            out[i] = incoming_vals[ii]*scale;
          else
@@ -85,6 +85,8 @@ void multi_tensor_unscale_cuda(
{
  using namespace at;

  AT_CHECK(nblocks > 0, "nblocks is not > 0");

  int addresses_x = gpu_tensor_addresses.size(1);

  // <.< >.>  i don't see any cops.  i'm going to access the pointers directly.
......
@@ -76,7 +76,7 @@ class TestCache(unittest.TestCase):
            param.grad = None

        loss = model(self.x).sum()
        self.handle._default_scaler._loss_scale = 4.0
        with self.handle.scale_loss(loss, dummy_optimizer) as scaled_loss:
            scaled_loss.backward()
......
@@ -51,6 +51,7 @@ class TestFP16Optimizer(unittest.TestCase):
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_loss_scaling(self):
        ref_optim = torch.optim.Adam(self.ref_model.parameters())
......