OpenDAS / apex / Commits / d137b800

Commit d137b800, authored Feb 24, 2019 by Michael Carilli

    Stashing work

Parent: 80a3f3ca

Changes: 5 changed files, with 64 additions and 78 deletions (+64 -78)
    apex/amp/_initialize.py                             +61  -32
    apex/optimizers/fp16_optimizer.py                     +2   -2
    csrc/multi_tensor_apply.cuh                           +1   -1
    tests/run_fp16_optimizer/__init__.py                  +0   -0
    tests/run_fp16_optimizer/test_fp16_optimizer.py       +0  -43
apex/amp/_initialize.py  (+61 -32)

@@ -4,23 +4,9 @@ import functools
 from apex.fp16_utils import convert_network
 from ._amp_state import _amp_state
 from .scaler import LossScaler
-from ..fp16_utils import FP16_Optimizer
+from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
+from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
+from ..optimizers import FusedAdam
-
-
-def check_params_fp32(model):
-    for name, param in model.named_parameters():
-        if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
-            print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.initialize, you do not need to call .half() on your model\n"
-                  "before passing it, no matter what optimization level you choose.".format(
-                  name, param.type()))
-
-    for name, buf in model.named_buffers():
-        if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
-            print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.initialize, you do not need to call .half() on your model\n"
-                  "before passing it, no matter what optimization level you choose.".format(
-                  name, buf.type()))
 
 
 def to_type(dtype, t):
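The check_params_fp32 helper removed here reappears below in a multi-model form; its job is unchanged. As a quick illustration of the condition it tests, param.type() on a CUDA fp16 parameter is not "torch.cuda.FloatTensor", so a model that was .half()'d before being handed to amp would trip the warning. A minimal sketch (the Linear layer is an arbitrary example, not from this commit):

    import torch

    model = torch.nn.Linear(1024, 16).cuda()
    print(next(model.parameters()).type())   # "torch.cuda.FloatTensor" -> no warning

    model.half()
    print(next(model.parameters()).type())   # "torch.cuda.HalfTensor" -> would trigger the warning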
@@ -48,6 +34,56 @@ def applier(value, fn):
         return value
 
 
+def check_models(models):
+    for model in models:
+        parallel_type = None
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            parallel_type = "torch.nn.parallel.DistributedDataParallel"
+        if isinstance(model, apex_DDP):
+            parallel_type = "apex.parallel.DistributedDataParallel"
+        if isinstance(model, torch.nn.parallel.DataParallel):
+            parallel_type = "torch.nn.parallel.DataParallel"
+        if parallel_type is not None:
+            raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
+                "Parallel wrappers should only be applied to the model(s) AFTER\n"
+                "the model(s) have been returned from amp.initialize.")
+
+
+def check_params_fp32(models):
+    for model in models:
+        for name, param in model.named_parameters():
+            if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
+                print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                      "When using amp.initialize, you do not need to call .half() on your model\n"
+                      "before passing it, no matter what optimization level you choose.".format(
+                      name, param.type()))
+
+        for name, buf in model.named_buffers():
+            if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
+                print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                      "When using amp.initialize, you do not need to call .half() on your model\n"
+                      "before passing it, no matter what optimization level you choose.".format(
+                      name, buf.type()))
+
+
+def check_optimizers(optimizers):
+    for optim in optimizers:
+        bad_optim_type = None
+        if isinstance(optim, FP16_Optimizer_general):
+            bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
+        if isinstance(optim, FP16_Optimizer_for_fused):
+            bad_optim_type = "apex.optimizers.FP16_Optimizer"
+        if bad_optim_type is not None:
+            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
+                "The optimizer(s) passed to amp.initialize() should be bare\n"
+                "instances of either ordinary Pytorch optimizers, or Apex fused\n"
+                "optimizers (currently just FusedAdam, but FusedSGD will be added\n"
+                "soon). You should not manually wrap your optimizer in either\n"
+                "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer.\n"
+                "amp.initialize will take care of that for you (if necessary) based\n"
+                "on the specified opt_level (and optional overridden properties).")
+
+
 def _initialize(models, optimizers, properties):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
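The three helpers added above fold the sanity checks into one place: check_models rejects models already wrapped in a parallel wrapper, check_params_fp32 warns about fp16 parameters and buffers, and check_optimizers rejects optimizers that were manually wrapped in either FP16_Optimizer flavor. In user terms the contract looks roughly like the sketch below; the model, optimizer, and opt_level are placeholders, and amp.initialize is the public entry point that eventually calls _initialize:

    import torch
    from apex import amp

    model = torch.nn.Linear(1024, 16).cuda()   # bare fp32 CUDA model: passes the checks
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Hand amp the bare model and bare optimizer first...
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # ...and only afterwards apply a parallel wrapper. Doing this before
    # amp.initialize is exactly what check_models raises a RuntimeError for.
    model = torch.nn.parallel.DataParallel(model)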
@@ -68,21 +104,11 @@ def _initialize(models, optimizers, properties):
     else:
         raise TypeError("models must be either a single model or a list of models.")
 
-    for model in models:
-        parallel_type = None
-        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            parallel_type = "torch.nn.parallel.DistributedDataParallel"
-        if isinstance(model, apex_DDP):
-            parallel_type = "apex.parallel.DistributedDataParallel"
-        if isinstance(model, torch.nn.parallel.DataParallel):
-            parallel_type = "torch.nn.parallel.DataParallel"
-        if parallel_type is not None:
-            raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
-                "Parallel wrappers should only be applied AFTER the model(s) have been "
-                "returned from amp.initialize.")
-
-    for model in models:
-        check_params_fp32(model)
+    check_models(models)
+
+    check_params_fp32(models)
+
+    check_optimizers(optimizers)
 
     # Stash master weights before casting the model.
     # if properties.master_weights:

@@ -112,8 +138,10 @@ def _initialize(models, optimizers, properties):
     if properties.master_weights:
         for i, optimizer in enumerate(optimizers):
+            if isinstance(optimizer, FusedAdam):
+                optimizers[i] = wrap_fused_adam(optimizer, properties)
             if properties.loss_scale == "dynamic":
-                optimizers[i] = FP16_Optimizer(optimizers[i], dynamic_loss_scale=True)
+                optimizers[i] = FP16_Optimizer_general(optimizers[i], dynamic_loss_scale=True)
             else:
                 optimizers[i] = FP16_Optimizer(optimizers[i], static_loss_scale=properties.loss_scale)
     else:

@@ -121,6 +149,7 @@ def _initialize(models, optimizers, properties):
                 optimizer.loss_scaler = LossScaler(properties.loss_scale)
 
     if properties.patch_torch_functions:
+        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale)
 
     if optimizers_was_list:
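Putting the master-weights branch together: FusedAdam instances go through wrap_fused_adam, while everything else is wrapped in the general-purpose FP16_Optimizer with either a dynamic or static loss scale. A condensed paraphrase of that decision follows; it is not the literal diff (the stashed code still refers to the bare FP16_Optimizer name in its static branch, which the general alias stands in for here), and it relies on names already imported or defined in this module:

    # Paraphrase of the wrapping logic above; "properties" is assumed to expose
    # master_weights and loss_scale, as it does in the surrounding code.
    def wrap_optimizer(optimizer, properties):
        if isinstance(optimizer, FusedAdam):
            return wrap_fused_adam(optimizer, properties)
        if properties.loss_scale == "dynamic":
            return FP16_Optimizer_general(optimizer, dynamic_loss_scale=True)
        return FP16_Optimizer_general(optimizer, static_loss_scale=properties.loss_scale)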
apex/optimizers/fp16_optimizer.py  (+2 -2)

@@ -26,7 +26,7 @@ except TypeError as err:
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer.
-    Design to be used in the same way but support only fused optimizers in apex.
+    Designed only to wrap apex.optimizers.FusedAdam.
     Refer to apex.fp16_utils documents for more information.
 
     Example::

@@ -179,7 +179,7 @@ class FP16_Optimizer(object):
     def backward(self, loss):
         """
-        :attr:`backward` performs the following conceptual steps:
+        :attr:`backward` performs the following steps:
 
         1. fp32_loss = loss.float()
         2. scaled_loss = fp32_loss*loss_scale
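The backward docstring enumerates the loss-scaling steps: cast the loss to fp32, multiply by the loss scale, and backpropagate the scaled loss. An end-to-end sketch of how this wrapper is meant to be driven is shown below; the model, data, and hyperparameters are made up for illustration, and the constructor arguments follow the documented apex usage rather than anything specific to this commit:

    import torch
    from apex.optimizers import FusedAdam, FP16_Optimizer

    # Made-up fp16 model and data, for illustration only.
    model = torch.nn.Linear(1024, 16).cuda().half()
    x = torch.randn(64, 1024, dtype=torch.float16, device="cuda")
    y = torch.randn(64, 16, dtype=torch.float16, device="cuda")

    optimizer = FusedAdam(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x).float(), y.float())
    optimizer.backward(loss)   # steps 1-2 from the docstring: cast to fp32, scale, backprop
    optimizer.step()           # unscale grads and apply the fused update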
csrc/multi_tensor_apply.cuh  (+1 -1)

@@ -112,9 +112,9 @@ void multi_tensor_apply(
       else
       {
         // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
-        tl.sizes[0] = tl.sizes[loc_tensor_info-1];
         for(int d = 0; d < depth; d++)
           tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
+        tl.sizes[0] = tl.sizes[loc_tensor_info-1];
         loc_tensor_info = 1;
       }
     }
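For orientation, this block appears to handle the "metadata buffer full" path of multi_tensor_apply: the partially processed tensor's size and pointers are copied back into slot 0 so the next kernel launch can resume it, and the change only reorders when sizes[0] is refreshed relative to the address copy. A rough Python rendering of that carry-over bookkeeping (illustrative only, not the CUDA source):

    # Rough, illustrative rendering of the carry-over step touched by this diff
    # (the real code fills a fixed-size struct that is handed to a CUDA kernel).
    def carry_over_last_tensor(sizes, addresses, depth, loc_tensor_info):
        # Copy the partially processed tensor's pointers down to slot 0...
        for d in range(depth):
            addresses[d][0] = addresses[d][loc_tensor_info - 1]
        # ...refresh its size in slot 0 (the line the diff moves below the loop)...
        sizes[0] = sizes[loc_tensor_info - 1]
        # ...and report that exactly one tensor now occupies the metadata buffer.
        return 1  # becomes the new loc_tensor_info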
tests/run_fp16_optimizer/__init__.py  (deleted, 100644 → 0, +0 -0)
tests/run_fp16_optimizer/test_fp16_optimizer.py  (deleted, 100644 → 0, +0 -43)

-import unittest
-import functools as ft
-import itertools as it
-
-import torch
-
-from apex.fp16_utils import FP16_Optimizer
-
-
-# Currently no-ops (tested via examples).
-# FP16_Optimizer to be deprecated and moved under unified Amp API.
-class TestFP16Optimizer(unittest.TestCase):
-
-    def setUp(self):
-        N, D_in, D_out = 64, 1024, 16
-        self.N = N
-        self.D_in = D_in
-        self.D_out = D_out
-        self.x = torch.randn((N, D_in), dtype=torch.float16, device='cuda')
-        self.y = torch.randn((N, D_out), dtype=torch.float16, device='cuda')
-        self.model = torch.nn.Linear(D_in, D_out).cuda().half()
-
-    # def tearDown(self):
-    #     pass
-
-    def test_minimal(self):
-        pass
-
-    def test_minimal_static(self):
-        pass
-
-    def test_minimal_dynamic(self):
-        pass
-
-    def test_closure(self):
-        pass
-
-    def test_closure_dynamic(self):
-        pass
-
-    def test_save_load(self):
-        pass
-
-if __name__ == '__main__':
-    unittest.main()
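The deleted test module was pure scaffolding: every test body is pass, and its header comment already marks FP16_Optimizer for deprecation under the unified Amp API. For reference, a filled-in test_minimal_static built from the same setUp fixtures might have looked roughly like the sketch below; it is hypothetical, not code that existed in the repository:

    import torch
    import unittest

    from apex.fp16_utils import FP16_Optimizer


    class TestFP16OptimizerSketch(unittest.TestCase):
        def test_minimal_static(self):
            # Same shapes and dtypes as the deleted setUp.
            model = torch.nn.Linear(1024, 16).cuda().half()
            x = torch.randn((64, 1024), dtype=torch.float16, device='cuda')
            y = torch.randn((64, 16), dtype=torch.float16, device='cuda')

            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x).float(), y.float())
            optimizer.backward(loss)
            optimizer.step()

            # A single scaled step should not blow up the fp16 parameters.
            for p in model.parameters():
                self.assertFalse(torch.isnan(p.float()).any().item())


    if __name__ == '__main__':
        unittest.main()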