OpenDAS / apex · Commit 6af5980e
Authored Apr 23, 2019 by Michael Carilli

Merging in FusedAdam treatment

Parents: 16a3bdf3, 7aad54f7
Showing 9 changed files with 664 additions and 250 deletions (+664, −250)
README.md                                     +4    −4
apex/amp/_initialize.py                       +5    −25
apex/amp/_process_optimizer.py                +111  −72
apex/amp/handle.py                            +4    −11
apex/optimizers/fused_adam.py                 +67   −15
apex/parallel/distributed.py                  +235  −118
csrc/fused_adam_cuda.cpp                      +4    −0
csrc/fused_adam_cuda_kernel.cu                +198  −0
tests/L0/run_mixed_adam/test_mixed_adam.py    +36   −5
README.md

# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)

...

@@ -29,7 +29,7 @@ different flags to `amp.initialize`.

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`.  It enables convenient multiprocess
distributed training, optimized for NVIDIA's NCCL communication library.

...
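A minimal usage sketch of the wrapper described above (an editor's illustration, not part of this diff; it assumes a one-process-per-GPU launch where torch.distributed.init_process_group(backend="nccl") has already run and LOCAL_RANK is provided by the launcher):

    import os
    import torch
    import apex.parallel

    local_rank = int(os.environ.get("LOCAL_RANK", 0))   # assumption: the launcher sets LOCAL_RANK
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Mirrors torch.nn.parallel.DistributedDataParallel: gradients are
    # all-reduced across processes during backward() using NCCL.
    model = apex.parallel.DistributedDataParallel(model)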
apex/amp/_initialize.py

...
@@ -114,29 +114,13 @@ def check_optimizers(optimizers):
         raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
                            "The optimizer(s) passed to amp.initialize() must be bare \n"
                            "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                           "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
-                           "soon).  You should not manually wrap your optimizer in either \n"
+                           "optimizers (FusedAdam or FusedSGD). \n"
+                           "You should not manually wrap your optimizer in either \n"
                            "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                            "amp.initialize will take care of that for you (if necessary) based \n"
                            "on the specified opt_level (and optional overridden properties).")


-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to ' \
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, ' \
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)


 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init

...
@@ -163,7 +147,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if not _amp_state.allow_incoming_model_not_fp32:
         check_params_fp32(models)

     check_optimizers(optimizers)

     # In the future, when FP16_Optimizer can be deprecated and master weights can

...
@@ -196,7 +180,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
             model.forward = patch_forward(model.forward)

         # State dict trick to recast any preexisting per-param state tensors
         for optimizer in optimizers:
             optimizer.load_state_dict(optimizer.state_dict())

     elif cast_model_outputs is not None:

...
@@ -212,11 +196,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
             model.forward = patch_forward(model.forward)

     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)

     _amp_state.loss_scalers = []
     for _ in range(num_losses):

...
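For context, a hedged sketch of how FusedAdam reaches amp after this change: the optimizer is no longer wrapped in FP16_Optimizer_for_fused, it simply flows through _process_optimizer like any other optimizer. The O2-style settings below mirror the restriction spelled out in the removed wrap_fused_adam message; the model is illustrative and not part of the diff.

    import torch
    from apex import amp
    from apex.optimizers import FusedAdam

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = FusedAdam(model.parameters(), lr=1e-3)

    # amp.initialize patches the optimizer via _process_optimizer, attaching the
    # *_fused backward hooks when it sees a FusedAdam instance.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=False,
                                      loss_scale="dynamic")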
apex/amp/_process_optimizer.py

...
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam

 class AmpOptimizerState(object):

...
@@ -73,6 +74,40 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())


+def post_backward_models_are_masters(scaler, params, stashed_grads):
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else: # param.grad is None and stashed_grad is None
+            continue
+
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            scaler.loss_scale(),
+            models_are_masters=True)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash)
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash

...
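A hedged illustration of the triage rules in the new helper (editor's sketch, assuming Apex is importable; the stub scaler below stands in for amp's LossScaler and is not part of the commit):

    import torch
    from apex.amp._process_optimizer import post_backward_models_are_masters

    class StubScaler:
        # Minimal stand-in exposing the three methods the helper calls.
        def loss_scale(self):
            return 128.0
        def unscale(self, model_grads, master_grads, scale, models_are_masters=False):
            for g in master_grads:
                g.div_(scale)                      # plain unscale: grad /= loss_scale
        def unscale_with_stashed(self, model_grads, stashed, master_grads):
            for g, s in zip(master_grads, stashed):
                g.div_(128.0).add_(s)              # unscale, then accumulate the stashed grad

    params = [torch.nn.Parameter(torch.ones(2)) for _ in range(3)]
    params[0].grad = torch.full((2,), 256.0)       # fresh grad, no stash -> unscale
    params[1].grad = torch.full((2,), 256.0)       # fresh grad + stash   -> unscale_with_stashed
    params[2].grad = None                          # stash only           -> adopt the stashed grad
    stashed = [None, torch.ones(2), torch.ones(2)]

    post_backward_models_are_masters(StubScaler(), params, stashed)
    # stashed is now [None, None, None]; each param.grad holds the unscaled (and, where
    # applicable, accumulated) gradient.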
@@ -106,7 +141,7 @@ def post_backward_with_master_weights(self, scaler):
         if fp16_param.grad is None and fp32_param.grad is not None:
             continue
         elif fp16_param.grad is not None and fp32_param.grad is None:
             fp32_param.grad = torch.empty_like(fp32_param)
             fp16_grads_needing_unscale.append(fp16_param.grad)
             new_fp32_grads.append(fp32_param.grad)
         elif fp16_param.grad is not None and fp32_param.grad is not None:

...
@@ -129,37 +164,10 @@ def post_backward_with_master_weights(self, scaler):
                                      preexisting_fp32_grads)

     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else: # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(scaler,
+                                     stash.all_fp32_from_fp32_params,
+                                     stash.all_fp32_from_fp32_grad_stash)


 def lazy_init_no_master_weights(self):

...
@@ -176,7 +184,7 @@ def lazy_init_no_master_weights(self):
             raise TypeError("Optimizer's parameters must be either "
                             "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                             "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]

...
@@ -206,37 +214,56 @@ def post_backward_no_master_weights(self, scaler):
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))

     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else: # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
+        post_backward_models_are_masters(scaler, params, stashed_grads)
+
+
+def prepare_backward_with_master_weights_fused(self):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_with_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm = multi_tensor_applier(stash.multi_tensor_l2norm,
+                                    stash.dummy_overflow_buf,
+                                    [grad_group])
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    stash.grad_norms = norm_groups
+
+
+def prepare_backward_no_master_weights_fused(self):
+    stash = self._amp_stash
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_no_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None


 def _master_params_to_model_params(self):

...
@@ -274,6 +301,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);

     if properties.master_weights:

...
@@ -286,7 +314,8 @@ def _process_optimizer(optimizer, properties):
         old_step = optimizer.step
         def new_step(self):
             retval = old_step()
-            self._master_params_to_model_params()
+            if not isinstance(self, FusedAdam):
+                self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
             for param in self._amp_stash.all_fp32_from_fp16_params:
                 param.grad = None

...
@@ -313,19 +342,29 @@ def _process_optimizer(optimizer, properties):
                 param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(prepare_backward_with_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(post_backward_with_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(prepare_backward_with_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(post_backward_with_master_weights_fused, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(prepare_backward_with_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(lazy_init_no_master_weights, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(prepare_backward_no_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(post_backward_no_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(prepare_backward_no_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(post_backward_no_master_weights_fused, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(prepare_backward_no_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(post_backward_no_master_weights, optimizer)

     return optimizer
apex/amp/handle.py

...
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused

 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.

...
@@ -82,13 +80,8 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer):
         optimizers = [optimizers]

-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()

     if ((not _amp_state.opt_properties.master_weights)
         and (not loss_scaler.dynamic)

...
@@ -113,8 +106,8 @@ def scale_loss(loss,
         for optimizer in optimizers:
             optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
         loss_scaler.clear_overflow_state()
         for optimizer in optimizers:
             optimizer._post_amp_backward(loss_scaler)

...
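For orientation, a hedged sketch of the user-facing flow these hooks serve (an editor's illustration, not part of the diff): amp.scale_loss yields a scaled loss and, around the backward pass, drives each optimizer's _prepare_amp_backward and _post_amp_backward hooks, which is where the fused and non-fused paths above diverge.

    import torch
    from apex import amp

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    loss = model(torch.randn(32, 1024, device="cuda")).float().mean()

    # On exit, scale_loss calls the optimizer's _post_amp_backward hook, which
    # either unscales gradients directly or (for FusedAdam) stashes the scale,
    # grads and norms for the fused step() to consume.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()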
apex/optimizers/fused_adam.py

...
@@ -2,6 +2,8 @@ import types
 import torch
 import importlib
+from ..multi_tensor_apply import multi_tensor_applier

 class FusedAdam(torch.optim.Optimizer):

     """Implements Adam algorithm. Currently GPU-only.  Requires Apex to be installed via

...
@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer):
             adds eps to the bias-corrected second moment estimate before
             evaluating square root instead of adding it to the square root of
             second moment estimate as in the original paper. (default: False)
+        use_mt (boolean, optional): use multi tensor apply for lower launch
+            latency. (default: False)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980

...
@@ -35,10 +39,21 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction = True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False):
+                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
+                 amp_scale_adjustment=1.0):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

+        self._use_multi_tensor = False
+        if use_mt:
+            if not multi_tensor_applier.available:
+                print("Warning: multi_tensor_applier is unavailable")
+            else:
+                self._use_multi_tensor = True
+                self._overflow_buf = torch.cuda.IntTensor([0])
+
+        self._amp_scale_adjustment = amp_scale_adjustment
+
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,

...
@@ -66,6 +81,12 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

+        if hasattr(self, "_amp_stash"):
+            grads = self._amp_stash.grads
+            output_params = self._amp_stash.output_params
+            scale = self._amp_stash.scale * self._amp_scale_adjustment
+            grad_norms = self._amp_stash.grad_norms
+
         if grads is None:
             grads_group = [None]*len(self.param_groups)
         # backward compatibility

...
@@ -105,6 +126,12 @@ class FusedAdam(torch.optim.Optimizer):
             bias_correction = 1 if group['bias_correction'] else 0

+            if self._use_multi_tensor:
+                if output_params:
+                    tensorlists = [[],[],[],[],[]]
+                else:
+                    tensorlists = [[],[],[],[]]
+
             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                 if p.grad is None and grad is None:

...
@@ -130,18 +157,43 @@ class FusedAdam(torch.optim.Optimizer):
                 state['step'] += 1

                 out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
-                fused_adam_cuda.adam(p.data,
-                                     out_p,
-                                     exp_avg,
-                                     exp_avg_sq,
-                                     grad,
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     combined_scale,
-                                     state['step'],
-                                     self.eps_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                if self._use_multi_tensor:
+                    pl = [p.data, exp_avg, exp_avg_sq, grad]
+                    if output_param is not None:
+                        pl.append(out_p)
+
+                    for tl, t in zip(tensorlists, pl):
+                        tl.append(t)
+                else:
+                    fused_adam_cuda.adam(p.data,
+                                         out_p,
+                                         exp_avg,
+                                         exp_avg_sq,
+                                         grad,
+                                         group['lr'],
+                                         beta1,
+                                         beta2,
+                                         group['eps'],
+                                         combined_scale,
+                                         state['step'],
+                                         self.eps_mode,
+                                         bias_correction,
+                                         group['weight_decay'])
+
+            if self._use_multi_tensor:
+                multi_tensor_applier(fused_adam_cuda.adam_mt,
+                                     self._overflow_buf,
+                                     tensorlists,
+                                     group['lr'],
+                                     beta1,
+                                     beta2,
+                                     group['eps'],
+                                     combined_scale,
+                                     state['step'],
+                                     self.eps_mode,
+                                     bias_correction,
+                                     group['weight_decay'])

         return loss
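A hedged usage sketch of the new multi-tensor path (editor's illustration, not part of the diff; it assumes a CUDA build of Apex with the fused_adam_cuda extension, and the model is arbitrary):

    import torch
    from apex.optimizers import FusedAdam

    model = torch.nn.Linear(4096, 4096).cuda()

    # use_mt=True batches every parameter of a group into a single
    # multi_tensor_applier launch of fused_adam_cuda.adam_mt instead of one
    # fused_adam_cuda.adam call per parameter, cutting kernel-launch overhead
    # when there are many small tensors.
    optimizer = FusedAdam(model.parameters(), lr=1e-3, use_mt=True)

    loss = model(torch.randn(32, 4096, device="cuda")).sum()
    loss.backward()
    optimizer.step()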
apex/parallel/distributed.py

(This diff is collapsed in the page view and not shown; the file changed by +235, −118.)
csrc/fused_adam_cuda.cpp

...
@@ -3,6 +3,9 @@
 // CUDA forward declaration
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

+void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

...
@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         m.def("adam", &adam, "Adam optimized CUDA implementation.");
+        m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
 }
csrc/fused_adam_cuda_kernel.cu

...
@@ -9,6 +9,10 @@
 #include "ATen/Type.h"
 #include "ATen/AccumulateType.h"
 #include <THC/THCGeneral.h>
+#include "multi_tensor_apply.cuh"
+
+#define BLOCK_SIZE 512
+#define ILP 4

 #include "type_shim.h"

...
@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
         }
 }

+template <int DEPTH, typename T, typename GRAD_T>
+struct AdamFunctor
+{
+    __device__ __forceinline__ void operator()(
+        int chunk_size,
+        volatile int* noop_gmem,
+        TensorListMetadata<DEPTH>& tl,
+        const float b1,
+        const float b2,
+        const float eps,
+        const float grad_scale,
+        const float step_size,
+        adamMode_t mode,
+        const float decay)
+    {
+        int tensor_loc = tl.block_to_tensor[blockIdx.x];
+        int chunk_idx = tl.block_to_chunk[blockIdx.x];
+        int n = tl.sizes[tensor_loc];
+
+        T* p = (T *)tl.addresses[0][tensor_loc];
+        p += chunk_idx*chunk_size;
+        T* m = (T *)tl.addresses[1][tensor_loc];
+        m += chunk_idx*chunk_size;
+        T* v = (T *)tl.addresses[2][tensor_loc];
+        v += chunk_idx*chunk_size;
+        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
+        g += chunk_idx*chunk_size;
+        GRAD_T* p_copy = NULL;
+        if (DEPTH == 5) {
+            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
+            p_copy += chunk_idx*chunk_size;
+        }
+
+        n -= chunk_idx*chunk_size;
+
+        T incoming_p[ILP];
+        T incoming_m[ILP];
+        T incoming_v[ILP];
+        T incoming_g[ILP];
+
+        for(int i_start = 0;
+            i_start < n && i_start < chunk_size;
+            i_start += blockDim.x*ILP) {
+
+            #pragma unroll
+            for(int ii = 0; ii < ILP; ii++) {
+                incoming_p[ii] = 0;
+                incoming_m[ii] = 0;
+                incoming_v[ii] = 0;
+                incoming_g[ii] = 0;
+
+                int i = i_start + threadIdx.x + ii*blockDim.x;
+                if (i < n && i < chunk_size) {
+                    incoming_p[ii] = p[i];
+                    incoming_m[ii] = m[i];
+                    incoming_v[ii] = v[i];
+                    incoming_g[ii] = static_cast<T>(g[i]);
+                }
+            }
+
+            // note for clarification to future michael:
+            // From a pure memory dependency perspective, there's likely no point unrolling
+            // the write loop, since writes just fire off once their LDGs arrive.
+            // Put another way, the STGs are dependent on the LDGs, but not on each other.
+            // There is still compute ILP benefit from unrolling the loop though.
+            #pragma unroll
+            for(int ii = 0; ii < ILP; ii++) {
+                int j = i_start + threadIdx.x + ii*blockDim.x;
+                if (j < n && j < chunk_size) {
+                    T scaled_grad = incoming_g[ii]/grad_scale;
+                    m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
+                    v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+                    float denom;
+                    if (mode == ADAM_MODE_0)
+                        denom = sqrtf(v[j] + eps);
+                    else // Mode 1
+                        denom = sqrtf(v[j]) + eps;
+                    float update = (m[j]/denom) + (decay*incoming_p[ii]);
+                    p[j] = incoming_p[ii] - (step_size*update);
+                    if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];
+                }
+            }
+        }
+    }
+};
+
 void fused_adam_cuda(
         at::Tensor & p,
         at::Tensor & p_copy,

...
@@ -135,3 +226,110 @@ void fused_adam_cuda(
       THCudaCheck(cudaGetLastError());

 }
+
+void fused_adam_cuda_mt(
+    int chunk_size,
+    at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
+    float lr,
+    float beta1,
+    float beta2,
+    float eps,
+    float grad_scale,
+    int step,
+    int mode,
+    int bias_correction,
+    float decay) {
+
+    //Constants
+    float step_size = 0;
+    if (bias_correction == 1) {
+        const float bias_correction1 = 1 - std::pow(beta1, step);
+        const float bias_correction2 = 1 - std::pow(beta2, step);
+        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
+    }
+    else {
+        step_size = lr;
+    }
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    size_t tl_sz = tensor_lists.size();
+    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
+
+    if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) {
+        // all other values should be fp32 for half gradients
+        AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float,
+                   "expected parameter to be of float type");
+        // dispatch is done on the gradient type
+        if (tl_sz == 5) {
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                using accscalar_t = at::acc_type<scalar_t, true>;
+                multi_tensor_apply<5>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<5, accscalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        } else {
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                using accscalar_t = at::acc_type<scalar_t, true>;
+                multi_tensor_apply<4>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<4, accscalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+    } else {
+        if (tl_sz == 5) {
+            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                multi_tensor_apply<5>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<5, scalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        } else {
+            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                multi_tensor_apply<4>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<4, scalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+    }
+    THCudaCheck(cudaGetLastError());
+}
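As a cross-check on the kernel above, a hedged scalar Python sketch of the per-element math AdamFunctor performs (editor's illustration, not part of the commit; step_size folds in the bias correction exactly as fused_adam_cuda_mt computes it on the host, and mode selects whether eps sits inside or outside the square root):

    import math

    def adam_reference(p, m, v, g, lr, beta1, beta2, eps, grad_scale,
                       step, mode, bias_correction, decay):
        # Host-side step size, mirroring fused_adam_cuda_mt.
        if bias_correction:
            step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
        else:
            step_size = lr

        # Per-element update, mirroring AdamFunctor's inner loop.
        scaled_grad = g / grad_scale
        m = beta1 * m + (1 - beta1) * scaled_grad
        v = beta2 * v + (1 - beta2) * scaled_grad * scaled_grad
        denom = math.sqrt(v + eps) if mode == 0 else math.sqrt(v) + eps
        update = m / denom + decay * p
        p = p - step_size * update
        return p, m, v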
tests/L0/run_mixed_adam/test_mixed_adam.py

...
@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
         ref_param = []
         tst_param = []
         for tensor in tensors:
             ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
+        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
+        if tst_adam_option:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
+        else:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

...
@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
     def get_max_diff(self, ref_param, tst_param):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
-            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
-            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
+            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
+            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()

             if max_abs_diff_p > max_abs_diff:  max_abs_diff = max_abs_diff_p
             if max_rel_diff_p > max_rel_diff:  max_rel_diff = max_rel_diff_p

...
@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
         self.assertLessEqual(max_abs_diff, self.max_abs_diff)
         self.assertLessEqual(max_rel_diff, self.max_rel_diff)

+    def test_multi_tensor(self):
+        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
+        ref_adam_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
+                           'weight_decay': 0, 'amsgrad': False}
+        tst_adam_option = dict(ref_adam_option, **{'use_mt': True})
+
+        tensors = []
+        fp16_params = []
+        for size in sizes:
+            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
+            fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
+        ref_param, tst_param, ref_optim, tst_optim = \
+            self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)
+
+        for i in range(self.iters):
+            half_grads = self.gen_mixed_grad(ref_param, tst_param)
+            ref_optim.step()
+            tst_optim.step(grads=half_grads, output_params=fp16_params)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
+            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
+                fp16_params)
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+

 if __name__ == '__main__':
     script_path = os.path.dirname(os.path.realpath(__file__))

...