OpenDAS / apex · Commits · 75139ca3

Commit 75139ca3, authored Apr 25, 2019 by Michael Carilli
Commit message: let's see
Parent: e0f2ffa5

Showing 3 changed files, with 126 additions and 137 deletions:

    apex/amp/_process_optimizer.py    +32  -43
    apex/optimizers/fused_sgd.py      +57  -60
    csrc/multi_tensor_sgd_kernel.cu   +37  -34
apex/amp/_process_optimizer.py

@@ -11,6 +11,20 @@ class AmpOptimizerState(object):
     pass
 
+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
+
 def lazy_init_with_master_weights(self):
     stash = self._amp_stash
     stash.fp16_groups = []

@@ -277,6 +291,8 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 # implementations until I have them both working.
 #####################################################################################
 
+# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
+# outside the kernel, so we must accumulate directly into the model grads.
 def prepare_backward_with_master_weights_FusedSGD(self):
     stash = self._amp_stash

@@ -284,60 +300,33 @@ def prepare_backward_with_master_weights_FusedSGD(self):
         self._lazy_init_maybe_master_weights()
         stash.lazy_init_called = True
 
-    for i, param in enumerate(stash.all_fp16_params):
-        stash.all_fp16_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
-
-    for i, param in enumerate(stash.all_fp32_from_fp32_params):
-        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
-        # Set up to leverage grad copy elision:
-        param.grad = None
-
 
 def post_backward_with_master_weights_FusedSGD(self, scaler):
     stash = self._amp_stash
-    stash.scale = scaler.loss_scale()
-    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
-    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
-
-    norm_groups = []
-    skip = False
-    for grad_group in stash.grads:
-        norm = multi_tensor_applier(
-            stash.multi_tensor_l2norm,
-            stash.dummy_overflow_buf,
-            [grad_group])
-        # Still syncing here for now.
-        norm = float(norm)
-        norm_groups.append(norm)
-        if norm == float('inf') or norm == -float('inf') or norm != norm:
-            skip = True
-
-    if skip:
-        scaler._overflow_buf.fill_(1.)
-        scaler._has_overflow = True
-
-    stash.grad_norms = norm_groups
+    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                   (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
+    for params, stashed_grads in split_types:
+        post_backward_models_are_masters(scaler, params, stashed_grads)
 
 
 def prepare_backward_no_master_weights_FusedSGD(self):
-    stash = self._amp_stash
-    if not stash.lazy_init_called:
-        self._lazy_init_maybe_master_weights()
-        stash.lazy_init_called = True
+    prepare_backward_no_master_weights(self)
 
 
 def post_backward_no_master_weights_FusedSGD(self, scaler):
-    stash = self._amp_stash
-    stash.scale = scaler.loss_scale()
-    stash.grads = None
-    stash.output_params = None
-    stash.grad_norms = None
-
-
-def _master_params_to_model_params(self):
-    stash = self._amp_stash
-    if multi_tensor_applier.available:
-        if len(stash.all_fp16_params) > 0:
-            multi_tensor_applier(
-                stash.multi_tensor_scale,
-                stash.dummy_overflow_buf,
-                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
-                1.0)
-    else:
-        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+    post_backward_no_master_weights(self, scaler)
 
 
 def _process_optimizer(optimizer, properties):
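As context for the relocated helper above: _master_params_to_model_params pushes the fp32 master weights back into the fp16 model weights, either as one batched multi_tensor_scale launch covering the whole parameter list or, when the fused path is unavailable, as a per-group copy. A minimal pure-PyTorch sketch of the fallback behavior, for illustration only (copy_master_to_model is a hypothetical name, not apex's API):

    import torch

    def copy_master_to_model(model_params, master_params):
        # Mirror each fp32 master weight back into its fp16 model counterpart,
        # which is what the per-group fallback path does when the fused
        # multi-tensor kernel is not available.
        with torch.no_grad():
            for model_p, master_p in zip(model_params, master_params):
                model_p.copy_(master_p)  # implicit fp32 -> fp16 downcast

    # Hypothetical usage:
    master = [torch.randn(4, dtype=torch.float32)]
    model = [m.clone().half() for m in master]
    copy_master_to_model(model, master)

Batching the same copy through a single multi-tensor launch is preferred when available because it avoids one kernel launch per parameter tensor.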
apex/optimizers/fused_sgd.py

@@ -80,6 +80,22 @@ class FusedSGD(Optimizer):
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
+    def get_momentums(self, params):
+        momentums = []
+        for p in params:
+            param_state = self.state[p]
+            # torch.optim.SGD initializes momentum in the main loop, we have
+            # to do it here, and track whether or not we've done so, so that
+            # momentum application can be skipped in the main kernel.
+            if 'momentum_buffer' not in param_state:
+                first_run = True
+                buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
+                momentums.append(buf)
+            else:
+                first_run = False
+                momentums.append(param_state['momentum_buffer'])
+        return momentums, first_run
+
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -91,73 +107,59 @@ class FusedSGD(Optimizer):
         if closure is not None:
             loss = closure()
 
-        for group in self.param_groups:
+        explicit_master_params = (hasattr(self, "_amp_stash") and
+                                  hasattr(self._amp_stash, "fp32_from_fp16_groups"))
+
+        for gid, group in enumerate(self.param_groups):
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             dampening = group['dampening']
             nesterov = group['nesterov']
 
-            params = [p for p in group['params'] if p is not None]
-            grads = [p.grad for p in params]
-            momentums = []
-            for p in params:
-                param_state = self.state[p]
-                # torch.optim.SGD initializes momentum in the main loop, we have
-                # to do it here, and track whether or not we've done so, so that
-                # momentum application can be skipped in the main kernel.
-                if 'momentum_buffer' not in param_state:
-                    first_run = True
-                    buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
-                    momentums.append(buf)
-                else:
-                    first_run = False
-                    momentums.append(param_state['momentum_buffer'])
-
-            # We have all parameters now, split them into appropriate groups for
-            # parallel execution, following the 4 possible combos that the underlying
-            # kernels support:
-            # grad_type, param_type, momentum_type, requires_fp16_copy
-            # 1. fp16, fp16, fp16, No
-            # 2. fp16, fp32, fp32, No
-            # 3. fp16, fp32, fp32, Yes
-            # 4. fp32, fp32, fp32, No
-            # As in the kernel, easier to hardcode these options
-            # Store only indices into the weight / grad / momentum lists
-            # { gradient-type : { param-type : List } | List }
-            param_sets = {'fp16': {'fp16': [], 'fp32': []}, 'fp32': []}
-            for i, (g, p) in enumerate(zip(grads, params)):
-                if g.dtype == torch.float16:
-                    # fp16 grads, fp16 params
-                    if p.dtype == torch.float16:
-                        param_sets['fp16']['fp16'].append(i)
-                    # fp16 grads, fp32 params
-                    elif p.dtype == torch.float32:
-                        param_sets['fp16']['fp32'].append(i)
-                    else:
-                        raise RuntimeError('fp16 gradients need either fp16 or fp32 weights')
-                # fp32 grads, fp32 params
-                elif g.dtype == torch.float32:
-                    param_sets['fp32'].append(i)
-                else:
-                    raise RuntimeError('gradients must either be fp16 or fp32')
-
-            def launch_sgd_set(param_set):
-                local_params, local_grads, local_momentums = [], [], []
-                if len(param_set) == 0:
-                    return
-                # launch update using multi tensor applier
-                # modifies weight and momentum values inplace.
-                multi_tensor_applier(
-                    self.multi_tensor_sgd,
-                    self._dummy_overflow_buf,
-                    # Note: Need to do this as list comprehensions otherwise
-                    # things don't seem to update properly.
-                    [[grads[i] for i in param_set],
-                     [params[i] for i in param_set],
-                     [momentums[i] for i in param_set]],
-                    weight_decay,
-                    momentum,
-                    dampening,
+            # For each group, there are 3 possible combinations we need to consider:
+            # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
+            # 1. fp16, fp16, fp16, No
+            # 2. fp32, fp32, fp32, No
+            # 3. fp16, fp32, fp32, Yes
+            first_runs = [True, True]
+
+            # I think a bit of code divergence in exchange for naming clarity is worthwhile
+            if explicit_master_params:
+                stash = self._amp_stash
+
+                fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
+                fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
+                fp32_from_fp16_params = [p for i, p in enumerate(stash.fp32_from_fp16_groups[gid])
+                                         if stash.fp16_groups[gid][i].grad is not None]
+                fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
+                fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
+                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
+
+                launch_sets = [[fp16_model_grads, fp32_from_fp16_params,
+                                fp32_from_fp16_momentums, fp16_model_params],
+                               [fp32_grads, fp32_params, fp32_momentums]]
+            else:
+                fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
+                fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
+                fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
+
+                fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
+                fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
+                fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
+
+                launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
+                               [fp32_grads, fp32_params, fp32_momentums]]
+
+            for launch_set, first_run in zip(launch_sets, first_runs):
+                multi_tensor_applier(
+                    self.multi_tensor_sgd,
+                    self._dummy_overflow_buf,
+                    # Note: Need to do this as list comprehensions otherwise
+                    # things don't seem to update properly.
+                    launch_set,
+                    weight_decay,
+                    momentum,
+                    dampening,
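The rewritten step above drops the index-based param_sets bookkeeping in favor of explicit launch sets grouped by gradient and weight dtype, with one fused launch per set. A toy sketch of that partitioning idea, with an ordinary per-tensor update standing in for the fused kernel (partition_by_dtype and toy_step are illustrative names, not apex's API):

    import torch

    def partition_by_dtype(params):
        # Split parameters that have gradients into fp16 and fp32 sets,
        # mirroring the grouping step() performs before each fused launch.
        # Only the two dtypes the kernel handles are considered here.
        sets = {torch.float16: [], torch.float32: []}
        for p in params:
            if p.grad is not None:
                sets[p.dtype].append(p)
        return sets

    def toy_step(params, lr=0.1):
        # A real implementation hands each whole set to one fused kernel
        # launch; the inner loop here is just a stand-in.
        with torch.no_grad():
            for group in partition_by_dtype(params).values():
                for p in group:
                    p.add_(p.grad, alpha=-lr)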
@@ -166,9 +168,4 @@ class FusedSGD(Optimizer):
                     first_run,
                     self.wd_after_momentum)
 
-            # Explicitly go over the cases
-            launch_sgd_set(param_sets['fp16']['fp16'])
-            launch_sgd_set(param_sets['fp16']['fp32'])
-            launch_sgd_set(param_sets['fp32'])
-
         return loss
csrc/multi_tensor_sgd_kernel.cu

@@ -150,23 +150,23 @@ void multi_tensor_sgd_cuda(
   bool wd_after_momentum)
 {
   auto num_tensors = tensor_lists.size();
-  auto grad_type = tensor_lists[0][0].type().scalarType();
-  auto weight_type = tensor_lists[0][0].type().scalarType();
+  auto grad_type = tensor_lists[0][0].scalar_type();
+  auto weight_type = tensor_lists[1][0].scalar_type();
 
-  // We have 4 potentials to handle here, in terms of
+  // We have 3 possibilities to handle here, in terms of
   // grad_type, param_type, momentum_type, requires_fp16_copy
   // 1. fp16, fp16, fp16, No
-  // 2. fp16, fp32, fp32, No
+  // 2. fp32, fp32, fp32, No
   // 3. fp16, fp32, fp32, Yes
-  // 4. fp32, fp32, fp32, No
   // It's easier to hardcode these possibilities than to use
   // switches etc. to handle the cross-product of cases where
   // we don't want the majority of them.
 
   // Case 1. fp16, fp16, fp16, No
   if(grad_type == at::ScalarType::Half &&
      weight_type == at::ScalarType::Half &&
      num_tensors == 3) {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,

@@ -182,15 +182,34 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum);
   }
   // Case 2. fp16, fp32, fp32, No
-  else if(grad_type == at::ScalarType::Half &&
-          weight_type == at::ScalarType::Float &&
-          num_tensors == 3) {
+  // else if(grad_type == at::ScalarType::Half &&
+  //         weight_type == at::ScalarType::Float &&
+  //         num_tensors == 3) {
+  //   multi_tensor_apply<3>(
+  //     BLOCK_SIZE,
+  //     chunk_size,
+  //     noop_flag,
+  //     tensor_lists,
+  //     SGDFunctor<3, at::Half, float>(),
+  //     wd,
+  //     momentum,
+  //     dampening,
+  //     lr,
+  //     nesterov,
+  //     first_run,
+  //     wd_after_momentum);
+  // }
+  // Case 2. fp32, fp32, fp32, No
+  else if(grad_type == at::ScalarType::Float &&
+          weight_type == at::ScalarType::Float &&
+          num_tensors == 3) {
     multi_tensor_apply<3>(
       BLOCK_SIZE,
       chunk_size,
       noop_flag,
       tensor_lists,
-      SGDFunctor<3, at::Half, float>(),
+      SGDFunctor<3, float, float>(),
       wd,
       momentum,
       dampening,

@@ -200,9 +219,10 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum);
   }
   // Case 3. fp16, fp32, fp32, Yes
   else if(grad_type == at::ScalarType::Half &&
           weight_type == at::ScalarType::Float &&
           num_tensors == 4) {
     multi_tensor_apply<4>(
       BLOCK_SIZE,
       chunk_size,

@@ -217,25 +237,8 @@ void multi_tensor_sgd_cuda(
       first_run,
       wd_after_momentum);
   }
-  // Case 4. fp32, fp32, fp32, No
-  else if(grad_type == at::ScalarType::Float &&
-          weight_type == at::ScalarType::Float &&
-          num_tensors == 3) {
-    multi_tensor_apply<3>(
-      BLOCK_SIZE,
-      chunk_size,
-      noop_flag,
-      tensor_lists,
-      SGDFunctor<3, float, float>(),
-      wd,
-      momentum,
-      dampening,
-      lr,
-      nesterov,
-      first_run,
-      wd_after_momentum);
-  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
              "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
   }
 }
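For reference, a Python-side restatement of the dispatch above (not part of apex): after this change the kernel accepts exactly three layouts of tensor_lists, keyed by gradient dtype, weight dtype, and the number of tensor lists.

    import torch

    # The three supported combinations, matching the comment block in the kernel:
    # (grad dtype, weight dtype, number of tensor lists)
    SUPPORTED = {
        (torch.float16, torch.float16, 3),  # 1. fp16 grads, fp16 weights, fp16 momentum
        (torch.float32, torch.float32, 3),  # 2. fp32 grads, fp32 weights, fp32 momentum
        (torch.float16, torch.float32, 4),  # 3. fp16 grads, fp32 master weights + fp16 model copy
    }

    def check_sgd_launch(tensor_lists):
        # Hypothetical helper: reject layouts the kernel would AT_ERROR on.
        grad_type = tensor_lists[0][0].dtype
        weight_type = tensor_lists[1][0].dtype
        combo = (grad_type, weight_type, len(tensor_lists))
        if combo not in SUPPORTED:
            raise RuntimeError("unsupported multi_tensor_sgd combination: {}".format(combo))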