Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5c9797a2
Unverified
Commit
5c9797a2
authored
Aug 21, 2021
by
Zhenhua Han
Committed by
GitHub
Aug 21, 2021
Browse files
Fix interface of CGO's accelerator to support pytorch-lightning 1.4.2 (#4075)
parent
cae1e6d4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
8 deletions
+37
-8
dependencies/recommended.txt
dependencies/recommended.txt
+1
-1
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
+35
-6
nni/retiarii/evaluator/pytorch/cgo/trainer.py
nni/retiarii/evaluator/pytorch/cgo/trainer.py
+1
-1
No files found.
dependencies/recommended.txt
View file @
5c9797a2
...
@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
...
@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.2.8, < 1.4.2
pytorch-lightning >= 1.4.2
onnx
onnx
peewee
peewee
graphviz
graphviz
...
...
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
View file @
5c9797a2
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
from
pytorch_lightning.accelerators.accelerator
import
Accelerator
from
pytorch_lightning.accelerators.accelerator
import
Accelerator
from
pytorch_lightning.plugins.training_type.training_type_plugin
import
TrainingTypePlugin
from
pytorch_lightning.plugins.training_type.training_type_plugin
import
TrainingTypePlugin
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.plugins
import
Plugin
from
pytorch_lightning.plugins
import
Plugin
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
from
pytorch_lightning.plugins.environments
import
ClusterEnvironment
...
@@ -53,6 +54,13 @@ class BypassPlugin(TrainingTypePlugin):
...
@@ -53,6 +54,13 @@ class BypassPlugin(TrainingTypePlugin):
"""Perform a all_gather on all processes """
"""Perform a all_gather on all processes """
return
tensor
return
tensor
def
teardown
(
self
):
"""
This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""
pass
@
property
@
property
def
root_device
(
self
)
->
torch
.
device
:
def
root_device
(
self
)
->
torch
.
device
:
return
torch
.
device
(
self
.
device
)
return
torch
.
device
(
self
.
device
)
...
@@ -78,10 +86,13 @@ class BypassPlugin(TrainingTypePlugin):
...
@@ -78,10 +86,13 @@ class BypassPlugin(TrainingTypePlugin):
def
get_accelerator_connector
(
def
get_accelerator_connector
(
num_processes
:
int
=
1
,
num_processes
:
int
=
1
,
devices
:
Optional
[
Union
[
List
[
int
],
str
,
int
]]
=
None
,
tpu_cores
:
Optional
[
Union
[
List
[
int
],
str
,
int
]]
=
None
,
tpu_cores
:
Optional
[
Union
[
List
[
int
],
str
,
int
]]
=
None
,
ipus
:
Optional
[
int
]
=
None
,
distributed_backend
:
Optional
[
str
]
=
None
,
distributed_backend
:
Optional
[
str
]
=
None
,
auto_select_gpus
:
bool
=
False
,
accelerator
:
Optional
[
Union
[
str
,
Accelerator
]]
=
None
,
gpus
:
Optional
[
Union
[
List
[
int
],
str
,
int
]]
=
None
,
gpus
:
Optional
[
Union
[
List
[
int
],
str
,
int
]]
=
None
,
auto_select_gpus
:
bool
=
False
,
num_nodes
:
int
=
1
,
num_nodes
:
int
=
1
,
sync_batchnorm
:
bool
=
False
,
sync_batchnorm
:
bool
=
False
,
benchmark
:
bool
=
False
,
benchmark
:
bool
=
False
,
...
@@ -90,17 +101,35 @@ def get_accelerator_connector(
...
@@ -90,17 +101,35 @@ def get_accelerator_connector(
precision
:
int
=
32
,
precision
:
int
=
32
,
amp_backend
:
str
=
'native'
,
amp_backend
:
str
=
'native'
,
amp_level
:
str
=
'O2'
,
amp_level
:
str
=
'O2'
,
plugins
:
Optional
[
Union
[
List
[
Union
[
Plugin
,
ClusterEnvironment
,
str
]],
Plugin
,
ClusterEnvironment
,
str
]]
=
None
):
plugins
:
Optional
[
Union
[
List
[
Union
[
Plugin
,
ClusterEnvironment
,
str
]],
Plugin
,
ClusterEnvironment
,
str
]]
=
None
,
**
other_trainier_kwargs
)
->
AcceleratorConnector
:
gpu_ids
=
Trainer
().
_parse_devices
(
gpus
,
auto_select_gpus
,
tpu_cores
)
return
AcceleratorConnector
(
return
AcceleratorConnector
(
num_processes
,
tpu_cores
,
distributed_backend
,
auto_select_gpus
,
gpus
,
num_nodes
,
sync_batchnorm
,
benchmark
,
num_processes
,
replace_sampler_ddp
,
deterministic
,
precision
,
amp_backend
,
amp_level
,
plugins
devices
,
tpu_cores
,
ipus
,
distributed_backend
,
accelerator
,
gpus
,
gpu_ids
,
num_nodes
,
sync_batchnorm
,
benchmark
,
replace_sampler_ddp
,
deterministic
,
precision
,
amp_backend
,
amp_level
,
plugins
,
)
)
@
serialize_cls
@
serialize_cls
class
BypassAccelerator
(
Accelerator
):
class
BypassAccelerator
(
Accelerator
):
def
__init__
(
self
,
precision_plugin
=
None
,
device
=
"cpu"
):
def
__init__
(
self
,
precision_plugin
=
None
,
device
=
"cpu"
,
**
trainer_kwargs
):
if
precision_plugin
is
None
:
if
precision_plugin
is
None
:
precision_plugin
=
get_accelerator_connector
().
precision_plugin
precision_plugin
=
get_accelerator_connector
(
**
trainer_kwargs
).
select_precision_plugin
()
# pylint: disable=abstract-class-instantiated
# pylint: disable=abstract-class-instantiated
super
().
__init__
(
precision_plugin
=
precision_plugin
,
training_type_plugin
=
BypassPlugin
(
device
))
super
().
__init__
(
precision_plugin
=
precision_plugin
,
training_type_plugin
=
BypassPlugin
(
device
))
nni/retiarii/evaluator/pytorch/cgo/trainer.py
View file @
5c9797a2
...
@@ -26,6 +26,6 @@ class Trainer(pl.Trainer):
...
@@ -26,6 +26,6 @@ class Trainer(pl.Trainer):
if
use_cgo
:
if
use_cgo
:
if
"accelerator"
in
trainer_kwargs
:
if
"accelerator"
in
trainer_kwargs
:
raise
ValueError
(
"accelerator should not be set when cross-graph optimization is enabled."
)
raise
ValueError
(
"accelerator should not be set when cross-graph optimization is enabled."
)
trainer_kwargs
[
'accelerator'
]
=
BypassAccelerator
(
device
=
'cpu'
)
trainer_kwargs
[
'accelerator'
]
=
BypassAccelerator
(
device
=
'cpu'
,
**
trainer_kwargs
)
super
().
__init__
(
**
trainer_kwargs
)
super
().
__init__
(
**
trainer_kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment