Commit 1f693b92 authored by Michael Carilli

stashing work

parent b2f63c48
from .amp import init, half_function, float_function, promote_function,\
-register_half_function, register_float_function, register_promote_function
+register_half_function, register_float_function, register_promote_function,\
+register
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ..fp16_utils import FP16_Optimizer
from .frontend import *
import functools
import itertools
import torch
_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()
def _decorator_helper(orig_fn, cast_fn, wrap_fn):
def wrapper(*args, **kwargs):
handle = _DECORATOR_HANDLE
@@ -21,19 +25,23 @@ def _decorator_helper(orig_fn, cast_fn, wrap_fn):
return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
return wrapper
# Decorator form
def half_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)
def float_function(fn):
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
def promote_function(fn):
wrap_fn = functools.partial(wrap.make_promote_wrapper)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
# Registry form
def register_half_function(module, name):
if not hasattr(module, name):
@@ -41,18 +49,21 @@ def register_half_function(module, name):
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
def register_float_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
def register_promote_function(module, name):
if not hasattr(module, name):
raise ValueError('No function named {} in module {}.'.format(
name, module))
_USER_PROMOTE_REGISTRY.add((module, name))
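As a point of reference, here is a minimal usage sketch of the two forms above. The fused_gelu function and the choice of torch.softmax are hypothetical examples, not part of this commit.

    # Hypothetical user code exercising the decorator and registry forms (sketch only).
    import torch
    from apex import amp

    # Decorator form: inputs to the decorated function are cast to FP16 at call time,
    # once amp.init() has installed a handle.
    @amp.half_function
    def fused_gelu(x):
        return x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))

    # Registry form: queue an existing library function for wrapping; the hook is
    # actually inserted when amp.init() runs.
    amp.register_float_function(torch, "softmax")

    handle = amp.init(enabled=True)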
# Top-level function to insert _all_ the hooks.
def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
global _DECORATOR_HANDLE
......
import torch
from . import amp  # needed below, where amp.opt_properties is stashed on the amp submodule
from .initialize import initialize
class Properties(object):
"""
The purpose of this class is twofold: to establish a set of default properties,
and to route setting of these attributes through __setattr__ so that (in theory)
they can be checked for consistency with other existing args.
"""
def __init__(self):
self.options = {
"opt_level" : None,
"cast_model_type" : None,
"cast_torch_functions" : False,
"cast_batchnorm" : None,
"master_weights" : False,
"loss_scale" : 1.0,
"flatten_model_params" : False,
"flatten_master_params" : False,
"enable_ddp_interop" : False}
"""
This function will allow updating several options at a time without routing through
__setattr__ checks, to avoid "you can't get there from here" scenarios.
"""
def update_options_dict(self, new_options):
for k, v in new_options.items():
if k in self.options:
self.options[k] = v
else:
raise ValueError("Tried to set unexpected option {}".format(k))
"""
The members of options are not direct attributes of self, so __getattr__ is ok.
This borrows from the logic in torch.nn.Module.
"""
def __getattr__(self, name):
if "options" in self.__dict__:
options = self.__dict__["options"]
if name in options:
return options[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
def __setattr__(self, name, value):
if "options" in self.__dict__:
if name in self.options:
print("setting {}".format(name))
self.options[name] = value
else:
super(Properties, self).__setattr__(name, value)
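For illustration, a short REPL-style sketch of how this routing behaves (values are hypothetical):

    # Sketch: attribute reads and writes are routed into the options dict.
    props = Properties()
    print(props.loss_scale)             # 1.0 -- missing attributes are looked up in self.options
    props.loss_scale = 128.0            # goes through __setattr__ (also prints "setting loss_scale")
    print(props.options["loss_scale"])  # 128.0 -- the new value landed in the options dict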
""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
class O3:
brief = "O3: Pure FP16 training."
more = "Calls .half() on your model, converting the entire model to FP16.\n"\
"A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
"so you don't need to change your data pipeline.\n"\
"This mode is useful for establishing a performance ceiling.\n"\
"It's also possible training may 'just work' in this mode.\n"\
"If not, try other optimization levels."
def __call__(self, properties):
properties.opt_level = "O3"
properties.cast_model_type = torch.float16
properties.cast_torch_functions = False
properties.cast_batchnorm = False
properties.master_weights = False
properties.loss_scale = 1.0
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O2:
brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
"to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
"your data pipeline.\n"\
"O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
"these master weights, then copy the master weights into the FP16 model weights.\n"\
"Master weights can also improve convergence and stability."
def __call__(self, properties):
properties.opt_level = "O2"
properties.cast_model_type = torch.float16
properties.cast_torch_functions = False
properties.cast_batchnorm = torch.float32
properties.master_weights = True
properties.loss_scale = 128.0
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O1:
brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
more = "The type of your model's weights is not altered. However, internally,\n"\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
"while operations that might benefit from the additional stability of FP32 are patched\n"\
"to cast their inputs to fp32.\n"\
"O1 is the safest way to try mixed precision training, and is recommended when\n"\
"trying mixed precision training for the first time."
def __call__(self, properties):
properties.opt_level = "O1"
properties.cast_model_type = False
properties.cast_torch_functions = True
properties.cast_batchnorm = False
properties.master_weights = False
properties.loss_scale = "dynamic"
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
class O0:
brief = "O0: Pure FP32 training.\n"
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
"FP16 arithmetic, although other optimizations like parameter flattening and DDP interop\n"\
"may still be requested.\n"
def __call__(self, properties):
properties.opt_level = "O0"
properties.cast_model_type = torch.float32
properties.cast_torch_functions = False
properties.cast_batchnorm = False
properties.master_weights = False
properties.loss_scale = 1.0
properties.flatten_model_params = False
properties.flatten_master_params = False
properties.enable_ddp_interop = False
return properties # modified in place so this isn't really necessary
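These wrapper classes, together with the opt_levels table just below, are meant to be consumed roughly as follows; register later in this file does the same thing internally. A sketch, not part of this commit:

    # Sketch: stamp the O2 defaults onto a fresh Properties instance.
    props = O2()(Properties())        # equivalently: opt_levels["O2"](Properties())
    print(O2.brief)                   # "O2: FP16 training with FP32 batchnorm and FP32 master weights."
    print(props.cast_batchnorm)       # torch.float32 -- batchnorms are kept in FP32 under O2
    print(props.master_weights)       # True
    print(props.loss_scale)           # 128.0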
opt_levels = {"O3": O3(),
"O2": O2(),
"O1": O1(),
"O0": O0()}
def check_params_fp32(model):
for name, param in model.named_parameters():
if param.type() != "torch.cuda.FloatTensor":
print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.register, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
name, param.type()))
for name, param in model.named_buffers():
if param.type() != "torch.cuda.FloatTensor":
print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
"When using amp.register, you do not need to call .half() on your model\n"
"before passing it, no matter what optimization level you choose.".format(
name, param.type()))
# allow user to directly pass Properties struct as well?
def register(enabled=False,
optimizers=None,
models=None,
opt_level=None,
cast_model_type=None,
cast_torch_functions=None,
cast_batchnorm=None,
master_weights=None,
loss_scale=None,
flatten_model_params=None,
flatten_master_params=None,
enable_ddp_interop=None):
if not enabled:
return
if opt_level not in opt_levels:
raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
else:
amp.opt_properties = opt_levels[opt_level](Properties())
print("Selected optimization level {}".format(opt_levels[opt_level].brief))
print("Defaults for this optimization level are:")
for k, v in amp.opt_properties.options.items():
print("{:20} : {}".format(k, v))
for model in models:
check_params_fp32(model)
# Gather the explicit keyword overrides so they can be applied on top of the defaults.
overrides = {"cast_model_type" : cast_model_type,
"cast_torch_functions" : cast_torch_functions,
"cast_batchnorm" : cast_batchnorm,
"master_weights" : master_weights,
"loss_scale" : loss_scale,
"flatten_model_params" : flatten_model_params,
"flatten_master_params" : flatten_master_params,
"enable_ddp_interop" : enable_ddp_interop}
print("Processing user overrides (keyword arguments that are not None)...")
for k, v in overrides.items():
if v is not None:
setattr(amp.opt_properties, k, v)
print("After processing overrides, optimization options are:")
for k, v in amp.opt_properties.options.items():
print("{:20} : {}".format(k, v))
initialize(optimizers, models, amp.opt_properties)
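A hedged sketch of how this entry point is intended to be called. The model and optimizer below are hypothetical placeholders, and this API is still work in progress in this commit.

    # Hypothetical user script (sketch only; requires a CUDA device).
    import torch
    from apex import amp

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Select the O2 recipe and explicitly override one of its defaults.
    amp.register(enabled=True,
                 optimizers=[optimizer],
                 models=[model],
                 opt_level="O2",
                 loss_scale=512.0)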
def check_option_consistency(enabled=False,
opt_level=None,
cast_model_type=None,
cast_torch_functions=None,
cast_batchnorm=None,
master_weights=None,
loss_scale=None,
flatten_model_params=None,
flatten_master_params=None,
enable_ddp_interop=None):
"""
Utility function that enables users to quickly check if the option combination they intend
to use is permitted. ``check_option_consistency`` does not require models or optimizers
to be constructed, and can be called at any point in the script. ``check_option_consistency``
is totally self-contained; it does not set any amp global state or affect anything outside
of itself.
"""
if not enabled:
return
if opt_level not in opt_levels:
raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
else:
opt_properties = opt_levels[opt_level](Properties())
print("Selected optimization level {}".format(opt_levels[opt_level].brief))
print("Defaults for this optimization level are:")
for k, v in opt_properties.options.items():
print("{:20} : {}".format(k, v))
# Gather the explicit keyword overrides so they can be checked against the defaults.
overrides = {"cast_model_type" : cast_model_type,
"cast_torch_functions" : cast_torch_functions,
"cast_batchnorm" : cast_batchnorm,
"master_weights" : master_weights,
"loss_scale" : loss_scale,
"flatten_model_params" : flatten_model_params,
"flatten_master_params" : flatten_master_params,
"enable_ddp_interop" : enable_ddp_interop}
print("Processing user overrides (keyword arguments that are not None)...")
for k, v in overrides.items():
if v is not None:
setattr(opt_properties, k, v)
print("After processing overrides, optimization options are:")
for k, v in opt_properties.options.items():
print("{:20} : {}".format(k, v))
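Per the docstring, a quick sketch of the intended call pattern; no models or optimizers are needed and no global amp state is touched:

    # Sketch: dry-run an option combination before constructing anything.
    check_option_consistency(enabled=True,
                             opt_level="O1",
                             loss_scale=1024.0)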
import torch
from torch._six import container_abcs, string_classes
import functools
def to_type(dtype, t):
if not t.is_cuda:
print("Warning: input tensor is not a CUDA tensor. Call .cuda() on your data before passing it.")
if t.requires_grad:
print("Warning: input data requires grad. Since input data is not a model parameter,\n"
"its gradients will not be properly allreduced by DDP.")
if t.is_floating_point():
return t.to(dtype)
return t
# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
def applier(value, fn):
if isinstance(value, torch.Tensor):
return fn(value)
elif isinstance(value, string_classes):
return value
elif isinstance(value, container_abcs.Mapping):
return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(applier(v, fn) for v in value)
else:
return value
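A small sketch of what applier does with nested containers. The toy batch below is hypothetical and uses CPU tensors, so to_type will print its non-CUDA warning; functools and torch are already imported at the top of this file.

    # Sketch: applier recursively maps fn over tensors inside nested containers.
    halve = functools.partial(to_type, torch.float16)
    batch = {"input": torch.randn(2, 3),
             "lengths": [torch.tensor([3, 2]), "pad"]}
    cast_batch = applier(batch, halve)
    print(cast_batch["input"].dtype)       # torch.float16 -- floating-point tensors are cast
    print(cast_batch["lengths"][0].dtype)  # torch.int64   -- integer tensors pass through unchanged
    print(cast_batch["lengths"][1])        # "pad"         -- strings are returned as-is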
def initialize(optimizers, models, properties):
# Stash master weights before casting the model.
# if properties.master_weights:
if properties.cast_model_type is not None:
# TODO: keep batchnorm modules in FP32 when cast_batchnorm is set.
# For now both branches cast the whole model; they are placeholders.
if properties.cast_batchnorm is not None:
for model in models:
model.to(properties.cast_model_type)
else:
for model in models:
model.to(properties.cast_model_type)
caster = functools.partial(to_type, properties.cast_model_type)
# Patch each model's forward method to cast incoming data to the correct type.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
return old_fwd(*applier(args, caster),
**applier(kwargs, caster))
return new_fwd
for model in models:
model.forward = patch_forward(model.forward)
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
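To make the forward patch concrete, a hedged end-to-end sketch (hypothetical toy model, requires a CUDA device; Properties and opt_levels come from frontend.py):

    # Sketch: after initialize(), the patched forward casts incoming FP32 data to FP16.
    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    props = opt_levels["O3"](Properties())   # cast_model_type == torch.float16

    initialize([opt], [model], props)

    x = torch.randn(4, 8, device="cuda")     # FP32 input from an unchanged data pipeline
    y = model(x)                             # new_fwd casts x to FP16 before calling the original forward
    print(y.dtype)                           # torch.float16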
@@ -322,6 +322,7 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
print("hook fired")
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
......
@@ -6,14 +6,14 @@
#include <assert.h>
#include <cuda_runtime.h>
-#define BLOCK_SIZE 1024
-#define NBLOCKS 160
+#define BLOCK_SIZE 256
+#define NBLOCKS 160*4
+#define ILP 4
// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
-// gradients should always be fp32.
+// gradients should always be fp32).
// This can be optimized with ILP but it's fine for now.
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
float* out,
@@ -22,12 +22,12 @@ __global__ void scale_reduce_overflow(in_t* in,
volatile int* overflow_global)
{
__shared__ int overflow;
-int tid = blockIdx.x*blockDim.x + threadIdx.x;
-int stride = gridDim.x*blockDim.x;
+float incoming_vals[4];
// Non-divergent exit condition for the __syncthreads
-for(int i = tid; i - threadIdx.x < n; i += stride)
+for(int chunk_start = blockIdx.x*blockDim.x*ILP;
+chunk_start < n;
+chunk_start += gridDim.x*blockDim.x*ILP)
{
if(threadIdx.x == 0)
overflow = *overflow_global;
@@ -37,19 +37,27 @@ __global__ void scale_reduce_overflow(in_t* in,
if(overflow == 1)
break;
-if(i < n)
+#pragma unroll
+for(int ii = 0; ii < ILP; ii++)
{
-float incoming_val = static_cast<float>(in[i]);
-if(isfinite(incoming_val))
-out[i] = incoming_val*scale;
-else
-*overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
-// This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
-// I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
-// It's possible we can just lean on the cache (no smem or syncs) and still be fast.
+incoming_vals[ii] = 0;
+int i = chunk_start + threadIdx.x + ii*blockDim.x;
+if(i < n)
+incoming_vals[ii] = static_cast<float>(in[i]);
}
-}
-}
+#pragma unroll
+for(int ii = 0; ii < ILP; ii++)
+{
+int i = chunk_start + threadIdx.x + ii*blockDim.x;
+if(i < n)
+if(isfinite(incoming_vals[ii]))
+out[i] = incoming_vals[ii]*scale;
+else
+*overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
+} // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
+} // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
+} // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
void scale_check_overflow_cuda
......