OpenDAS / apex
Commit c763f0fe, authored May 09, 2019 by Michael Carilli
Parent: f3528d99

    materialize_master_weights for FusedSGD

Showing 4 changed files with 64 additions and 34 deletions:

    apex/amp/_process_optimizer.py     +25  -17
    apex/optimizers/fused_sgd.py       +26  -10
    csrc/amp_C_frontend.cpp             +2   -1
    csrc/multi_tensor_sgd_kernel.cu    +11   -6
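In short: FusedSGD gains a materialize_master_grads flag (default True), and the fused SGD kernel gains a scale argument. When the flag is False, amp's backward hooks for FusedSGD no longer materialize fp32 master gradients; the still loss-scaled fp16 model gradients are stashed and handed straight to the kernel, which unscales them as it loads them. A minimal usage sketch follows; the toy model, opt_level, and training step are illustrative assumptions and not part of this commit, only the materialize_master_grads keyword is.

# Minimal usage sketch (assumes a CUDA build of apex). Only the
# materialize_master_grads flag comes from this commit; everything else
# (model, opt_level, loss) is an illustrative assumption.
import torch
from apex import amp
from apex.optimizers import FusedSGD

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedSGD(model.parameters(), lr=0.1, momentum=0.9,
                     materialize_master_grads=True)  # flag introduced by this commit
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

x = torch.randn(8, 1024, device="cuda")
loss = model(x).float().pow(2).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()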
apex/amp/_process_optimizer.py

@@ -130,7 +130,8 @@ def prepare_backward_with_master_weights(self):
     self._amp_lazy_init()
 
     for i, param in enumerate(stash.all_fp16_params):
-        # Set up to leverage grad copy elision:
+        # Set up to leverage grad copy elision.
+        # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused.
         param.grad = None
 
     # for i, param in enumerate(stash.all_fp32_from_fp16_params):

@@ -298,31 +299,38 @@ def post_backward_no_master_weights_FusedAdam(self, scaler):
 # FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params
 # outside the kernel, so we must accumulate directly into the model grads.
 def prepare_backward_with_master_weights_FusedSGD(self):
-    stash = self._amp_stash
+    if self.materialize_master_grads:
+        prepare_backward_with_master_weights(self)
+    else:
+        stash = self._amp_stash
         self._amp_lazy_init()
 
         for i, param in enumerate(stash.all_fp16_params):
             stash.all_fp16_grad_stash[i] = param.grad
             # Set up to leverage grad copy elision:
             param.grad = None
 
         for i, param in enumerate(stash.all_fp32_from_fp32_params):
             stash.all_fp32_from_fp32_grad_stash[i] = param.grad
             # Set up to leverage grad copy elision:
             param.grad = None
 
 
 def post_backward_with_master_weights_FusedSGD(self, scaler):
-    stash = self._amp_stash
+    if self.materialize_master_grads:
+        post_backward_with_master_weights(self, scaler)
+    else:
+        # TODO: handle gradient clipping and removal of any lingering scale here.
+        stash = self._amp_stash
         self._amp_lazy_init()
 
         split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
                        (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash))
 
         for params, stashed_grads in split_types:
             post_backward_models_are_masters(scaler, params, stashed_grads)
 
 
 def prepare_backward_no_master_weights_FusedSGD(self):
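The prepare/post pair above follows the stash-and-elide pattern amp already uses: stash the current .grad, set it to None so backward() writes a fresh gradient tensor (grad copy elision), then unscale the new gradient and fold the stashed one back in. A self-contained sketch of that pattern, with hypothetical names and a plain loss-scale divide standing in for what apex's scaler and post_backward_models_are_masters actually do:

# Standalone sketch of the stash-and-elide pattern; the loss_scale handling is a
# simplification of apex's scaler, and the helper names are hypothetical.
import torch

def prepare_backward(params, grad_stash):
    for i, p in enumerate(params):
        grad_stash[i] = p.grad   # keep whatever gradient was already accumulated
        p.grad = None            # let backward() allocate a fresh .grad (copy elision)

def post_backward(params, grad_stash, loss_scale):
    for i, p in enumerate(params):
        if p.grad is not None:
            p.grad.div_(loss_scale)          # strip the loss scale from the new grad
        stashed = grad_stash[i]
        if stashed is not None:
            if p.grad is None:
                p.grad = stashed             # param unused this iteration
            else:
                p.grad.add_(stashed)         # fold previously accumulated grads back in
        grad_stash[i] = None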
apex/optimizers/fused_sgd.py

@@ -51,7 +51,8 @@ class FusedSGD(Optimizer):
     def __init__(self, params, lr=required, momentum=0, dampening=0,
                  weight_decay=0, nesterov=False,
-                 wd_after_momentum=False):
+                 wd_after_momentum=False,
+                 materialize_master_grads=True):
         if lr is not required and lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if momentum < 0.0:

@@ -67,6 +68,8 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum
+        self.scale = 1.0
 
         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer

@@ -130,18 +133,30 @@ class FusedSGD(Optimizer):
             if explicit_master_params:
                 stash = self._amp_stash
 
-                fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
-                fp32_from_fp16_params = [p for i, p in enumerate(stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
-                fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
-
                 fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
                 fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
 
-                launch_sets = [[fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params],
-                               [fp32_grads, fp32_params, fp32_momentums]]
+                if materialize_master_grads:
+                    fp16_params = [p for i, p in enumerate(stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params]
+                else:
+                    fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
+                    fp32_from_fp16_params = [p for i, p in enumerate(stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
+                    fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
+
+                    fp16_set = [fp16_model_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_model_params]
+
+                launch_sets = [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
             else:
                 fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
                 fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]

@@ -168,6 +183,7 @@ class FusedSGD(Optimizer):
                     group['lr'],
                     nesterov,
                     first_run,
-                    self.wd_after_momentum)
+                    self.wd_after_momentum,
+                    self.scale)
 
         return loss
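In step(), the new branch chooses which gradients feed the fused kernel: with materialize_master_grads the fp32 master gradients (already unscaled by amp) are used, otherwise the raw fp16 model gradients are passed and the kernel unscales them via self.scale. The hypothetical helper below (not part of apex) captures the layout of the fp16 launch set; the real code builds the lists inline per parameter group and filters out params whose grad is None:

# Hypothetical helper illustrating the launch-set layout used above.
def build_fp16_launch_set(materialize_master_grads,
                          fp16_model_params, fp16_model_grads,
                          fp32_master_params, fp32_master_grads, fp32_momentums):
    if materialize_master_grads:
        grads = fp32_master_grads   # master grads were materialized (and unscaled) by amp
    else:
        grads = fp16_model_grads    # model grads, still carrying the loss scale
    # order expected by multi_tensor_sgd: [grads, master weights, momentums, fp16 model weights]
    return [grads, fp32_master_params, fp32_momentums, fp16_model_params]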
csrc/amp_C_frontend.cpp

@@ -16,7 +16,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum);
+  bool wd_after_momentum,
+  float scale);
 
 void multi_tensor_axpby_cuda(
   int chunk_size,
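This keeps the C++ declaration behind the amp_C binding in sync with the kernel. The sketch below shows what a direct call from Python could look like through multi_tensor_applier; only the trailing arguments (lr, nesterov, first_run, wd_after_momentum, scale) are visible in this diff, so the leading ones (overflow buffer, weight_decay, momentum, dampening) and the binding name amp_C.multi_tensor_sgd are assumptions recalled from apex and should be checked against the actual source:

# Sketch of a direct call into the fused SGD kernel for a single fp32 launch set.
# Assumes a CUDA build of apex; argument order before lr is an assumption.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")
params = [torch.zeros(16, device="cuda") for _ in range(2)]
grads = [torch.randn_like(p) for p in params]
momentums = [torch.zeros_like(p) for p in params]

multi_tensor_applier(
    amp_C.multi_tensor_sgd, overflow_buf,
    [grads, params, momentums],  # fp32/fp32/fp32 case: grads, weights, momentums
    0.0,    # weight_decay
    0.9,    # momentum
    0.0,    # dampening
    0.1,    # lr
    False,  # nesterov
    True,   # first_run
    False,  # wd_after_momentum
    1.0)    # scale: new argument added by this commit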
csrc/multi_tensor_sgd_kernel.cu

@@ -38,7 +38,8 @@ struct SGDFunctor
     float lr,
     bool nesterov,
     bool first_run,
-    bool wd_after_momentum)
+    bool wd_after_momentum,
+    float scale)
   {
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;

@@ -82,7 +83,7 @@ struct SGDFunctor
         int i = i_start + threadIdx.x + ii*blockDim.x;
         if (i < n && i < chunk_size)
         {
-          incoming_grads[ii] = static_cast<float>(grad_in[i]);
+          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
           incoming_weights[ii] = static_cast<float>(weight_in[i]);
           incoming_moms[ii] = static_cast<float>(mom_in[i]);
         }

@@ -146,7 +147,8 @@ void multi_tensor_sgd_cuda(
   float lr,
   bool nesterov,
   bool first_run,
-  bool wd_after_momentum)
+  bool wd_after_momentum,
+  float scale)
 {
   auto num_tensors = tensor_lists.size();
   auto grad_type = tensor_lists[0][0].scalar_type();

@@ -178,7 +180,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 2. fp16, fp32, fp32, No
   // else if (grad_type == at::ScalarType::Half &&

@@ -215,7 +218,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   // Case 3. fp16, fp32, fp32, Yes
   else if (grad_type == at::ScalarType::Half &&

@@ -234,7 +238,8 @@ void multi_tensor_sgd_cuda(
       lr,
       nesterov,
       first_run,
-      wd_after_momentum);
+      wd_after_momentum,
+      scale);
   }
   else
   {
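The only behavioral change inside the functor is that each incoming gradient is multiplied by scale as it is loaded, so a loss-scaled fp16 gradient can be unscaled on the fly (pass scale = 1.0 / loss_scale; with the default of 1.0 the kernel behaves exactly as before). Below is a plain-PyTorch reference of the per-parameter update for the fp32 case, assuming nesterov=False and wd_after_momentum=False and following the standard SGD-with-momentum recipe; it is a sketch, not a drop-in replica of the kernel:

# Reference sketch (fp32 only) of the update the fused kernel performs per element.
# `scale` mirrors the new kernel argument: the incoming gradient is scaled before use.
import torch

def sgd_step_reference(weight, grad, momentum_buf, lr, momentum, dampening,
                       weight_decay, scale, first_run):
    g = grad * scale                              # unscale the incoming gradient
    if weight_decay != 0.0:
        g = g + weight_decay * weight             # weight decay applied before momentum
    if momentum != 0.0:
        if first_run:
            momentum_buf.copy_(g)                 # initialize momentum to the grad
        else:
            momentum_buf.mul_(momentum).add_(g, alpha=1.0 - dampening)
        g = momentum_buf
    weight.add_(g, alpha=-lr)
    return weight, momentum_buf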