OpenDAS / apex · Commits

Commit 1f693b92, authored Feb 08, 2019 by Michael Carilli

stashing work

parent b2f63c48
Showing 6 changed files with 340 additions and 20 deletions (+340 -20)
apex/amp/__init__.py                  +2   -1
apex/amp/amp.py                       +11  -0
apex/amp/frontend.py                  +242 -0
apex/amp/initialize.py                +57  -0
apex/parallel/distributed.py          +1   -0
csrc/scale_check_overflow_kernel.cu   +27  -19
apex/amp/__init__.py

 from .amp import init, half_function, float_function, promote_function, \
-    register_half_function, register_float_function, register_promote_function
+    register_half_function, register_float_function, register_promote_function, \
+    register
apex/amp/amp.py

from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ..fp16_utils import FP16_Optimizer
from .frontend import *

import functools
import itertools

import torch

_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
_USER_PROMOTE_REGISTRY = set()

def _decorator_helper(orig_fn, cast_fn, wrap_fn):
    def wrapper(*args, **kwargs):
        handle = _DECORATOR_HANDLE
...
@@ -21,19 +25,23 @@ def _decorator_helper(orig_fn, cast_fn, wrap_fn):
        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
    return wrapper

# Decorator form
def half_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
    return _decorator_helper(fn, utils.maybe_half, wrap_fn)

def float_function(fn):
    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

def promote_function(fn):
    wrap_fn = functools.partial(wrap.make_promote_wrapper)
    return _decorator_helper(fn, utils.maybe_float, wrap_fn)

# Registry form
def register_half_function(module, name):
    if not hasattr(module, name):
...
@@ -41,18 +49,21 @@ def register_half_function(module, name):
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))

def register_float_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))

def register_promote_function(module, name):
    if not hasattr(module, name):
        raise ValueError('No function named {} in module {}.'.format(
            name, module))
    _USER_PROMOTE_REGISTRY.add((module, name))

# Top-level function to insert _all_ the hooks.
def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
    global _DECORATOR_HANDLE
...
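As context for the decorator and registry forms above, here is a hedged usage sketch; the wrapped function and the registered torch function are illustrative choices, not part of this commit:

```python
import torch
from apex import amp

# Decorator form: have amp cast this custom op's inputs to FP16 (with caching).
@amp.half_function
def fused_matmul(a, b):
    return torch.mm(a, b)

# Registry form: mark an existing library function by module and attribute name.
# Registration is expected to happen before amp.init() installs the hooks.
amp.register_float_function(torch, 'softmax')

handle = amp.init(enabled=True)
```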
apex/amp/frontend.py (new file, mode 100644)

import torch
from .initialize import initialize

class Properties(object):
    """
    The purpose of this class is twofold: to establish a set of default properties,
    and to route setting of these attributes through __setattr__ so that (in theory)
    they can be checked for consistency with other existing args.
    """
    def __init__(self):
        self.options = {
            "opt_level" : None,
            "cast_model_type" : None,
            "cast_torch_functions" : False,
            "cast_batchnorm" : None,
            "master_weights" : False,
            "loss_scale" : 1.0,
            "flatten_model_params" : False,
            "flatten_master_params" : False,
            "enable_ddp_interop" : False}

    """
    This function will allow updating several options at a time without routing through
    __setattr__ checks, to avoid "you can't get there from here" scenarios.
    """
    def update_options_dict(new_options):
        for k, v in new_options:
            if k in self.options:
                self.options[k] = v
            else:
                raise ValueError("Tried to set unexpected option {}".format(k))

    """
    The members of options are not direct attributes of self, so __getattr__ is ok.
    This borrows from the logic in torch.nn.Module.
    """
    def __getattr__(self, name):
        if "options" in self.__dict__:
            options = self.__dict__["options"]
            if name in options:
                return options[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        if "options" in self.__dict__:
            if name in self.options:
                print("setting {}".format(name))
                self.options[name] = value
        else:
            super(Properties, self).__setattr__(name, value)
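To make the attribute routing concrete, a minimal illustrative sketch, assuming the Properties class above is importable:

```python
props = Properties()

props.loss_scale = 128.0   # known option: routed into props.options by __setattr__
print(props.loss_scale)    # known option: looked up in props.options by __getattr__
print(props.options)       # the full defaults dict, with loss_scale now overridden
```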
""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
class
O3
:
brief
=
"O3: Pure FP16 training."
more
=
"Calls .half() on your model, converting the entire model to FP16.
\n
"
\
"A casting operation is also inserted to cast incoming Tensors to FP16,
\n
"
\
"so you don't need to change your data pipeline.
\n
"
\
"This mode is useful for establishing a performance ceiling.
\n
"
\
"It's also possible training may 'just work' in this mode.
\n
"
\
"If not, try other optimization levels."
def
__call__
(
self
,
properties
):
properties
.
opt_level
=
"O3"
,
properties
.
cast_model_type
=
torch
.
float16
properties
.
cast_torch_functions
=
False
properties
.
cast_batchnorm
=
False
properties
.
master_weights
=
False
properties
.
loss_scale
=
1.0
properties
.
flatten_model_params
=
False
properties
.
flatten_master_params
=
False
properties
.
enable_ddp_interop
=
False
return
properties
# modified in place so this isn't really necessary
class
O2
:
brief
=
"O2: FP16 training with FP32 batchnorm and FP32 master weights.
\n
"
more
=
"Calls .half() on your model, converting the entire model (except for batchnorms)
\n
"
\
"to FP16. Batchnorms are retained in FP32 for additional stability.
\n
"
\
"The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change
\n
"
\
"your data pipeline.
\n
"
\
"O2 creates FP32 master weights outside the model and patches any optimizers to update
\n
"
\
"these master weights, then copy the master weights into the FP16 model weights.
\n
"
\
"Master weights can also improve convergence and stability."
def
__call__
(
self
,
properties
):
properties
.
opt_level
=
"O2"
,
properties
.
cast_model_type
=
torch
.
float16
properties
.
cast_torch_functions
=
False
properties
.
cast_batchnorm
=
torch
.
float32
properties
.
master_weights
=
True
properties
.
loss_scale
=
128.0
properties
.
flatten_model_params
=
False
properties
.
flatten_master_params
=
False
properties
.
enable_ddp_interop
=
False
return
properties
# modified in place so this isn't really necessary
class
O1
:
brief
=
"O1: Insert automatic casts around Pytorch functions and Tensor methods.
\n
"
more
=
"The type of your model's weights is not altered. However, internally,
\n
"
\
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,
\n
"
\
"while operations that might benefit from the additional stability of FP32 are patched
\n
"
\
"to cast their inputs to fp32.
\n
"
\
"O1 is the safest way to try mixed precision training, and is recommended when
\n
"
\
"trying mixed precision training for the first time."
def
__call__
(
self
,
properties
):
properties
.
opt_level
=
"O1"
,
properties
.
cast_model_type
=
False
properties
.
cast_torch_functions
=
True
properties
.
cast_batchnorm
=
False
properties
.
master_weights
=
False
properties
.
loss_scale
=
"dynamic"
properties
.
flatten_model_params
=
False
properties
.
flatten_master_params
=
False
properties
.
enable_ddp_interop
=
False
return
properties
# modified in place so this isn't really necessary
class
O0
:
brief
=
"O0: Pure FP32 training.
\n
"
more
=
"Your models are checked to make sure parameters are FP32, but otherwise the
\n
"
\
"types of weights and internal Pytorch operations are not altered. This mode disables any
\n
"
\
"FP16 arithmetic, although other optimizations like parameter flattening and DDP interop
\n
"
\
"may still be requested.
\n
"
def
__call__
(
self
,
properties
):
properties
.
opt_level
=
"O0"
,
properties
.
cast_model_type
=
torch
.
float32
properties
.
cast_torch_functions
=
False
properties
.
cast_batchnorm
=
False
properties
.
master_weights
=
False
properties
.
loss_scale
=
1.0
properties
.
flatten_model_params
=
False
properties
.
flatten_master_params
=
False
properties
.
enable_ddp_interop
=
False
return
properties
# modified in place so this isn't really necessary
opt_levels
=
{
"O3"
:
O3
(),
"O2"
:
O2
(),
"O1"
:
O1
(),
"O0"
:
O0
()}
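Each entry in opt_levels is a callable object that stamps that level's defaults onto a Properties instance. A small sketch of the intended flow, assuming the classes above:

```python
props = opt_levels["O1"](Properties())   # apply O1 defaults to a fresh Properties

print(opt_levels["O1"].brief)            # one-line human-readable summary of the level
print(props.cast_torch_functions)        # True: O1 patches torch functions
print(props.loss_scale)                  # "dynamic"
```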
def check_params_fp32(model):
    for name, param in model.named_parameters():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.",
                  name, param.type())

    for name, param in model.named_buffers():
        if param.type() != "torch.cuda.FloatTensor":
            print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
                  "When using amp.register, you do not need to call .half() on your model\n"
                  "before passing it, no matter what optimization level you choose.",
                  name, param.type())
# allow user to directly pass Properties struct as well?
def register(enabled=False,
             optimizers=None,
             models=None,
             opt_level=None,
             cast_model_type=None,
             cast_torch_functions=None,
             cast_batchnorm=None,
             master_weights=None,
             loss_scale=None,
             flatten_model_params=None,
             flatten_master_params=None,
             enable_ddp_interop=None):

    if not enabled:
        return

    if opt_level not in opt_levels:
        raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        amp.opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}", opt_levels[opt_level].brief)
        print("Defaults for this optimization level are:")
        for k, v in amp.opt_properties.options:
            print("{:20} : {}", k, v)

    for model in models:
        check_params_fp32(model)

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs:
        if v is not None:
            setattr(amp.opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in amp.opt_properties.options:
        print("{:20} : {}", k, v)

    initialize(optimizers, models)
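The intended call pattern for this new register entry point looks roughly like the following. Treat it as a sketch only: the body above is still work in progress (for example, the kwargs loop has no kwargs collection yet), and the model and optimizer construction here is illustrative:

```python
import torch
from apex import amp

model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

amp.register(enabled=True,
             optimizers=[optimizer],
             models=[model],
             opt_level="O1",
             loss_scale="dynamic")   # keyword overrides replace the level's defaults
```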
def check_option_consistency(enabled=False,
                             opt_level=None,
                             cast_model_type=None,
                             cast_torch_functions=None,
                             cast_batchnorm=None,
                             master_weights=None,
                             loss_scale=None,
                             flatten_model_params=None,
                             flatten_master_params=None,
                             enable_ddp_interop=None):
    """
    Utility function that enables users to quickly check if the option combination they intend
    to use is permitted. ``check_option_consistency`` does not require models or optimizers
    to be constructed, and can be called at any point in the script. ``check_option_consistency``
    is totally self-contained; it does not set any amp global state or affect anything outside
    of itself.
    """

    if not enabled:
        return

    if opt_level not in opt_levels:
        raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        opt_properties = opt_levels[opt_level](Properties())
        print("Selected optimization level {}", opt_levels[opt_level].brief)
        print("Defaults for this optimization level are:")
        for k, v in opt_properties.options:
            print("{:20} : {}", k, v)

    print("Processing user overrides (additional kwargs that are not None)...")
    for k, v in kwargs:
        if v is not None:
            setattr(opt_properties, k, v)

    print("After processing overrides, optimization options are:")
    for k, v in opt_properties.options:
        print("{:20} : {}", k, v)
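Per its docstring, the function is meant to be called like this; again, the loops over options and kwargs above still need finishing, so treat this as the intended interface rather than working code:

```python
# Sanity-check an option combination before building any models or optimizers.
check_option_consistency(enabled=True,
                         opt_level="O2",
                         loss_scale=512.0,
                         master_weights=True)
```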
apex/amp/initialize.py (new file, mode 100644)

import torch
from torch._six import container_abcs, string_classes
import functools

def to_type(dtype, t):
    if not t.is_cuda:
        print("Warning: input tensor was not cuda. Call .cuda() on your data before passing it.")
    if t.requires_grad:
        print("Warning: input data requires grad. Since input data is not a model parameter,\n"
              "its gradients will not be properly allreduced by DDP.")
    if t.is_floating_point():
        return t.half()
    return t
# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
def applier(value, fn):
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, string_classes):
        return value
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    else:
        return value
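applier recursively maps fn over every tensor inside nested containers, while strings and non-tensor leaves pass through unchanged. An illustrative example, assuming a CUDA device and the helpers above:

```python
batch = {"input": torch.randn(4, 3).cuda(),               # float tensor: cast to FP16
         "lengths": [torch.tensor([3, 2, 4, 1]).cuda()],  # int tensor: passed through
         "id": "sample-0"}                                # string: passed through

caster = functools.partial(to_type, torch.float16)
half_batch = applier(batch, caster)
```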
def initialize(optimizers, models, properties):

    # Stash master weights before casting the model.
    # if properties.master_weights:

    if properties.cast_model_type is not None:
        if properties.cast_batchnorm is not None:
            for model in models:
                model.to(properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        caster = functools.partial(to_type, properties.cast_model_type)

        # Patch the forward method to cast incoming data to the correct type.
        def patch_forward(old_fwd):
            def new_fwd(*args, **kwargs):
                return old_fwd(*applier(args, caster),
                               **applier(kwargs, caster))
            return new_fwd

        model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())
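The two tricks used above (patching forward so incoming data is cast on the way in, and round-tripping the optimizer state dict so per-param state tensors are recast to match the new parameter dtype) can be seen in isolation in the following sketch; it is illustrative, not code from this commit:

```python
model = torch.nn.Linear(4, 4).cuda().half()            # model weights now FP16
caster = functools.partial(to_type, torch.float16)

old_fwd = model.forward
def new_fwd(*args, **kwargs):
    # Cast every incoming tensor (positional or keyword) before the real forward.
    return old_fwd(*applier(args, caster), **applier(kwargs, caster))
model.forward = new_fwd

out = model(torch.randn(2, 4).cuda())                  # FP32 input is cast to FP16 on the way in

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.load_state_dict(optimizer.state_dict())      # recasts any per-param state tensors
```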
apex/parallel/distributed.py

...
@@ -322,6 +322,7 @@ class DistributedDataParallel(Module):
         grad_acc = param_tmp.grad_fn.next_functions[0][0]

         def allreduce_hook(*unused):
+            print("hook fired")
             if self.delay_allreduce or self.needs_refresh:
                 # TODO: How do we want to handle multiple backward passes between
                 # each forward, e.g., backward passes with retain_graph=True?
...
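The grad_acc line above relies on the grad-accumulator hook trick: expanding a parameter produces a non-leaf tensor whose grad_fn points back at the parameter's AccumulateGrad node, and a hook registered on that node fires once the parameter's gradient is ready. A standalone sketch of the pattern (simplified; not the full DistributedDataParallel code):

```python
import torch

param = torch.nn.Parameter(torch.randn(3, 3))
param_tmp = param.expand_as(param)                  # non-leaf view of the parameter
grad_acc = param_tmp.grad_fn.next_functions[0][0]   # the AccumulateGrad node

def allreduce_hook(*unused):
    print("hook fired")   # the real DDP would bucket and allreduce param.grad here

grad_acc.register_hook(allreduce_hook)

param.sum().backward()    # hook fires when param's gradient has been accumulated
```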
csrc/scale_check_overflow_kernel.cu

...
@@ -6,14 +6,14 @@
 #include <assert.h>
 #include <cuda_runtime.h>

-#define BLOCK_SIZE 1024
-#define NBLOCKS 160
+#define BLOCK_SIZE 256
+#define NBLOCKS 160*4
+#define ILP 4

 // It makes sense to lock the output type to fp32 because the downscaled
 // grads should be master grads (and in the case of Amp, the params and their
-// gradients should always be fp32.
-// This can be optimized with ILP but it's fine for now.
+// gradients should always be fp32).
 template<typename in_t>
 __global__ void scale_reduce_overflow(in_t* in,
                                       float* out,
...
@@ -22,12 +22,12 @@ __global__ void scale_reduce_overflow(in_t* in,
                                       volatile int* overflow_global)
 {
   __shared__ int overflow;
-  int tid = blockIdx.x*blockDim.x + threadIdx.x;
-  int stride = gridDim.x*blockDim.x;
+  float incoming_vals[4];

   // Non-divergent exit condition for the __syncthreads
-  for(int i = tid; i - threadIdx.x < n; i += stride)
+  for(int chunk_start = blockIdx.x*blockDim.x*ILP; chunk_start < n; chunk_start += gridDim.x*blockDim.x*ILP)
   {
     if(threadIdx.x == 0)
       overflow = *overflow_global;
...
@@ -37,19 +37,27 @@ __global__ void scale_reduce_overflow(in_t* in,
     if(overflow == 1)
       break;

-    if(i < n)
-    {
-      float incoming_val = static_cast<float>(in[i]);
-      if(isfinite(incoming_val))
-        out[i] = incoming_val*scale;
-      else
-        *overflow_global = 1;
-        // Blindly fire off a write. These will race but that's ok.
-        // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
-        // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
-        // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
-    }
+#pragma unroll
+    for(int ii = 0; ii < ILP; ii++)
+    {
+      incoming_vals[ii] = 0;
+      int i = chunk_start + threadIdx.x + ii*blockDim.x;
+      if(i < n)
+        incoming_vals[ii] = static_cast<float>(in[i]);
+    }
+
+#pragma unroll
+    for(int ii = 0; ii < ILP; ii++)
+    {
+      int i = chunk_start + threadIdx.x + ii*blockDim.x;
+      if(i < n)
+        if(isfinite(incoming_vals[ii]))
+          out[i] = incoming_vals[ii]*scale;
+        else
+          *overflow_global = 1;
+          // Blindly fire off a write. These will race but that's ok.
+          // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
+          // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
+          // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
+    }
   }
 }

 void scale_check_overflow_cuda
...
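For context (not part of this commit), the overflow flag this kernel raises is what drives dynamic loss scaling on the Python side: scale the loss up before backward, check the gradients, then either unscale and step, or shrink the scale and skip the step. A minimal hedged sketch of that control loop in Python, under the usual dynamic-scaling assumptions:

```python
import torch

scale = 2.0 ** 15          # current loss scale
growth_interval = 2000     # overflow-free steps before trying a larger scale
good_steps = 0

def step_with_dynamic_scaling(loss, optimizer, params):
    global scale, good_steps
    (loss * scale).backward()                                    # scaled backward pass
    grads = [p.grad for p in params if p.grad is not None]
    overflow = any(not torch.isfinite(g).all() for g in grads)   # what the kernel detects
    if overflow:
        scale /= 2.0                                             # shrink the scale, skip the step
        good_steps = 0
    else:
        for g in grads:
            g.div_(scale)                                        # unscale (fused with the check in the kernel)
        optimizer.step()
        good_steps += 1
        if good_steps >= growth_interval:                        # periodically try a larger scale
            scale *= 2.0
            good_steps = 0
    optimizer.zero_grad()
```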