OpenDAS / apex · Commit fab319f1

allow for non-distributed envs (Windows) (#531)

Authored Oct 09, 2019 by Bram Vanroy; committed by mcarilli on Oct 09, 2019
Parent: 753c427a
Showing 3 changed files with 15 additions and 8 deletions (+15 -8):

    apex/__init__.py         +3 -1
    apex/amp/_initialize.py  +7 -5
    apex/amp/handle.py       +5 -2
apex/__init__.py

@@ -2,7 +2,9 @@
 import torch
 import warnings
 
-from . import parallel
+if torch.distributed.is_available():
+    from . import parallel
+
 from . import amp
 from . import fp16_utils
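This hunk carries the pattern the whole commit is built on: torch.distributed.is_available() reports whether the local PyTorch build ships the distributed package at all (False on the Windows builds the commit title refers to), so distributed-only submodules are imported only when that check passes. A minimal standalone sketch of the idea, not apex code itself:

import torch

if torch.distributed.is_available():
    # Safe to import modules that need torch.distributed at import time.
    import torch.distributed as dist
    print("torch.distributed is available")
else:
    # Non-distributed build (e.g. Windows at the time of this commit):
    # skip anything that would touch torch.distributed.
    print("running without torch.distributed")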
apex/amp/_initialize.py

@@ -2,6 +2,7 @@ import torch
 from torch._six import string_classes
 import functools
 import numpy as np
+import sys
 import warnings
 from ._amp_state import _amp_state, warn_or_err, container_abcs
 from .handle import disable_casts

@@ -10,8 +11,10 @@ from ._process_optimizer import _process_optimizer
 from apex.fp16_utils import convert_network
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
-from ..parallel import DistributedDataParallel as apex_DDP
-from ..parallel.LARC import LARC
+
+if torch.distributed.is_available():
+    from ..parallel import DistributedDataParallel as apex_DDP
+    from ..parallel.LARC import LARC
 
 
 def to_type(dtype, t):

@@ -62,7 +65,7 @@ def check_models(models):
         parallel_type = None
         if isinstance(model, torch.nn.parallel.DistributedDataParallel):
             parallel_type = "torch.nn.parallel.DistributedDataParallel"
-        if isinstance(model, apex_DDP):
+        if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP):
             parallel_type = "apex.parallel.DistributedDataParallel"
         if isinstance(model, torch.nn.parallel.DataParallel):
             parallel_type = "torch.nn.parallel.DataParallel"

@@ -139,11 +142,10 @@ class O2StateDictHook(object):
 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
-    from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
 
     optimizers_was_list = False
-    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
     elif optimizers is None:
         optimizers = []
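Since apex_DDP and LARC may never have been imported on such a build, each isinstance check is now preceded by a sys.modules lookup, and Python's short-circuiting `and` keeps the otherwise-undefined name from being evaluated. A self-contained sketch of that guard; optional_pkg and SpecialOptimizer are hypothetical placeholders, not apex names:

import sys

try:
    # Hypothetical optional dependency; on machines without it the import fails.
    from optional_pkg import SpecialOptimizer
except ImportError:
    pass

def as_list(opt):
    # The right-hand isinstance only runs when optional_pkg really was imported,
    # so the name SpecialOptimizer can never raise a NameError here.
    if 'optional_pkg' in sys.modules and isinstance(opt, SpecialOptimizer):
        return [opt]
    return opt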
apex/amp/handle.py

 import contextlib
 import warnings
+import sys
 
 import torch
 
 from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..parallel.LARC import LARC
+
+if torch.distributed.is_available():
+    from ..parallel.LARC import LARC
 
 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.

@@ -84,7 +87,7 @@ def scale_loss(loss,
         yield loss
         return
 
-    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
+    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in sys.modules and isinstance(optimizers, LARC)):
         optimizers = [optimizers]
 
     loss_scaler = _amp_state.loss_scalers[loss_id]
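Nothing changes in how amp is driven from user code; the guarded branches above simply short-circuit when torch.distributed is unavailable. A rough usage sketch following the documented apex amp API (the model, optimizer, and opt_level chosen here are arbitrary, and a CUDA device is assumed):

import torch
from apex import amp  # with this commit, importing apex no longer requires torch.distributed

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Standard mixed-precision setup; the LARC/apex_DDP checks above short-circuit
# on builds where torch.distributed.is_available() is False.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(8, 16).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()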