Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
baf60758
Unverified
Commit
baf60758
authored
Jul 13, 2022
by
Yuge Zhang
Committed by
GitHub
Jul 13, 2022
Browse files
Prepare for multi-framework support in NAS (#4976)
parent
0e835aa9
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
201 additions
and
8 deletions
+201
-8
docs/source/reference/others.rst
docs/source/reference/others.rst
+6
-0
nni/__init__.py
nni/__init__.py
+1
-0
nni/common/blob_utils.py
nni/common/blob_utils.py
+0
-1
nni/common/framework.py
nni/common/framework.py
+93
-0
nni/retiarii/codegen/__init__.py
nni/retiarii/codegen/__init__.py
+5
-1
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+2
-0
nni/retiarii/evaluator/__init__.py
nni/retiarii/evaluator/__init__.py
+6
-0
nni/retiarii/evaluator/tensorflow.py
nni/retiarii/evaluator/tensorflow.py
+0
-0
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+1
-1
nni/retiarii/experiment/__init__.py
nni/retiarii/experiment/__init__.py
+7
-1
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+2
-2
nni/retiarii/experiment/tensorflow.py
nni/retiarii/experiment/tensorflow.py
+2
-0
nni/retiarii/hub/__init__.py
nni/retiarii/hub/__init__.py
+8
-0
nni/retiarii/hub/tensorflow.py
nni/retiarii/hub/tensorflow.py
+0
-0
nni/retiarii/nn/__init__.py
nni/retiarii/nn/__init__.py
+8
-0
nni/retiarii/nn/tensorflow/api.py
nni/retiarii/nn/tensorflow/api.py
+11
-0
nni/retiarii/oneshot/__init__.py
nni/retiarii/oneshot/__init__.py
+6
-0
nni/retiarii/strategy/__init__.py
nni/retiarii/strategy/__init__.py
+0
-1
nni/retiarii/strategy/local_debug_strategy.py
nni/retiarii/strategy/local_debug_strategy.py
+1
-1
test/ut/nas/test_import_nodep.py
test/ut/nas/test_import_nodep.py
+42
-0
No files found.
docs/source/reference/others.rst
View file @
baf60758
Uncategorized Modules
=====================
nni.common.framework
--------------------
.. automodule:: nni.common.framework
:members:
nni.common.serializer
---------------------
...
...
nni/__init__.py
View file @
baf60758
...
...
@@ -9,6 +9,7 @@ except ModuleNotFoundError:
from
.runtime.log
import
_init_logger
_init_logger
()
from
.common.framework
import
*
from
.common.serializer
import
trace
,
dump
,
load
from
.experiment
import
Experiment
from
.runtime.env_vars
import
dispatcher_env_vars
...
...
nni/common/blob_utils.py
View file @
baf60758
...
...
@@ -15,7 +15,6 @@ import tqdm
__all__
=
[
'NNI_BLOB'
,
'load_or_download_file'
,
'upload_file'
,
'nni_cache_home'
]
# Blob that contains some downloadable files.
NNI_BLOB
=
'https://nni.blob.core.windows.net'
...
...
nni/common/framework.py
0 → 100644
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__
=
[
'set_default_framework'
,
'get_default_framework'
,
'shortcut_module'
,
'shortcut_framework'
]
import
importlib
import
os
import
sys
from
typing
import
Optional
,
cast
from
typing_extensions
import
Literal
framework_type
=
Literal
[
'pytorch'
,
'tensorflow'
,
'mxnet'
,
'none'
]
"""Supported framework types."""
ENV_NNI_FRAMEWORK
=
'NNI_FRAMEWORK'
def
framework_from_env
()
->
framework_type
:
framework
=
os
.
getenv
(
ENV_NNI_FRAMEWORK
,
'pytorch'
)
if
framework
not
in
framework_type
.
__args__
:
# type: ignore
raise
ValueError
(
f
'
{
framework
}
does not belong to
{
framework_type
.
__args__
}
'
)
# type: ignore
return
cast
(
framework_type
,
framework
)
DEFAULT_FRAMEWORK
=
framework_from_env
()
def
set_default_framework
(
framework
:
framework_type
)
->
None
:
"""Set default deep learning framework to simplify imports.
Some functionalities in NNI (e.g., NAS / Compression), relies on an underlying DL framework.
For different DL frameworks, the implementation of NNI can be very different.
Thus, users need import things tailored for their own framework. For example: ::
from nni.nas.xxx.pytorch import yyy
rather than: ::
from nni.nas.xxx import yyy
By setting a default framework, shortcuts will be made. As such ``nni.nas.xxx`` will be equivalent to ``nni.nas.xxx.pytorch``.
Another way to setting it is through environment variable ``NNI_FRAMEWORK``,
which needs to be set before the whole process starts.
If you set the framework with :func:`set_default_framework`,
it should be done before all imports (except nni itself) happen,
because it will affect other import's behaviors.
And the behavior is undefined if the framework is "re"-set in the middle.
The supported frameworks here are listed below.
It doesn't mean that they are fully supported by NAS / Compression in NNI.
* ``pytorch`` (default)
* ``tensorflow``
* ``mxnet``
* ``none`` (to disable the shortcut-import behavior).
Examples
--------
>>> import nni
>>> nni.set_default_framework('tensorflow')
>>> # then other imports
>>> from nni.nas.xxx import yyy
"""
# In case 'none' is written as None.
if
framework
is
None
:
framework
=
'none'
global
DEFAULT_FRAMEWORK
DEFAULT_FRAMEWORK
=
framework
def
get_default_framework
()
->
framework_type
:
"""Retrieve default deep learning framework set either with env variables or manually."""
return
DEFAULT_FRAMEWORK
def
shortcut_module
(
current
:
str
,
target
:
str
,
package
:
Optional
[
str
]
=
None
)
->
None
:
"""Make ``current`` module an alias of ``target`` module in ``package``."""
# Reference: https://github.com/dmlc/dgl/blob/d70a362dba8d46fd9838c79d76998a5e33f22cb7/python/dgl/nn/__init__.py#L27
mod
=
importlib
.
import_module
(
target
,
package
)
thismod
=
sys
.
modules
[
current
]
for
api
,
obj
in
mod
.
__dict__
.
items
():
setattr
(
thismod
,
api
,
obj
)
def
shortcut_framework
(
current
:
str
)
->
None
:
"""Make ``current`` a shortcut of ``current.framework``."""
if
get_default_framework
()
!=
'none'
:
# Throw ModuleNotFoundError if framework is not supported
shortcut_module
(
current
,
'.'
+
get_default_framework
(),
current
)
nni/retiarii/codegen/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.pytorch
import
model_to_pytorch_script
from
nni.common.framework
import
shortcut_framework
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/codegen/pytorch.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__
=
[
'model_to_pytorch_script'
]
import
logging
import
re
from
typing
import
Dict
,
List
,
Tuple
,
Any
,
cast
...
...
nni/retiarii/evaluator/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
from
.functional
import
FunctionalEvaluator
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/evaluator/tensorflow.py
0 → 100644
View file @
baf60758
nni/retiarii/execution/base.py
View file @
baf60758
...
...
@@ -146,7 +146,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation_summary
=
get_mutation_summary
(
model
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator can not be None'
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
return
BaseGraphData
(
codegen
.
pytorch
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
# type: ignore
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
...
...
nni/retiarii/experiment/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/experiment/pytorch.py
View file @
baf60758
...
...
@@ -20,7 +20,7 @@ from .config import (
RetiariiExeConfig
,
OneshotEngineConfig
,
BaseEngineConfig
,
PyEngineConfig
,
CgoEngineConfig
,
BenchmarkEngineConfig
)
from
..codegen
import
model_to_pytorch_script
from
..codegen
.pytorch
import
model_to_pytorch_script
from
..converter
import
convert_to_graph
from
..converter.graph_gen
import
GraphConverterWithShape
from
..execution
import
list_models
,
set_execution_engine
...
...
@@ -97,7 +97,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir
,
applied_mutators
=
preprocess_model
(
base_model
,
evaluator
,
applied_mutators
)
from
..strategy
import
_LocalDebugStrategy
from
..strategy
.local_debug_strategy
import
_LocalDebugStrategy
strategy
=
_LocalDebugStrategy
()
strategy
.
run
(
base_model_ir
,
applied_mutators
)
_logger
.
info
(
'local debug completed!'
)
...
...
nni/retiarii/experiment/tensorflow.py
0 → 100644
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
nni/retiarii/hub/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/hub/tensorflow.py
0 → 100644
View file @
baf60758
nni/retiarii/nn/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/nn/tensorflow/api.py
0 → 100644
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
__future__
import
annotations
import
tensorflow
as
tf
class
LayerChoice
(
tf
.
keras
.
Layer
):
# FIXME: This is only a draft to test multi-framework support, it's not unimplemented at all.
pass
nni/retiarii/oneshot/__init__.py
View file @
baf60758
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
from
.interface
import
BaseOneShotTrainer
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/retiarii/strategy/__init__.py
View file @
baf60758
...
...
@@ -5,6 +5,5 @@ from .base import BaseStrategy
from
.bruteforce
import
Random
,
GridSearch
from
.evolution
import
RegularizedEvolution
from
.tpe_strategy
import
TPEStrategy
,
TPE
from
.local_debug_strategy
import
_LocalDebugStrategy
from
.rl
import
PolicyBasedRL
from
.oneshot
import
DARTS
,
Proxyless
,
GumbelDARTS
,
ENAS
,
RandomOneShot
nni/retiarii/strategy/local_debug_strategy.py
View file @
baf60758
...
...
@@ -24,7 +24,7 @@ class _LocalDebugStrategy(BaseStrategy):
def
run_one_model
(
self
,
model
):
mutation_summary
=
get_mutation_summary
(
model
)
graph_data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
graph_data
=
BaseGraphData
(
codegen
.
pytorch
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
# type: ignore
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
file_name
=
f
'_generated_model/
{
random_str
}
.py'
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
...
...
test/ut/nas/test_import_nodep.py
0 → 100644
View file @
baf60758
"""To test the cases of importing NAS without certain DL libraries installed."""
import
argparse
import
subprocess
import
sys
import
pytest
def
import_related
(
mask_out
):
import
nni
nni
.
set_default_framework
(
mask_out
)
import
nni.retiarii
import
nni.retiarii.evaluator
import
nni.retiarii.hub
import
nni.retiarii.strategy
# FIXME: this doesn't work yet
import
nni.retiarii.experiment
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'masked'
,
choices
=
[
'torch'
,
'torch_none'
,
'tensorflow'
])
args
=
parser
.
parse_args
()
if
args
.
masked
==
'torch'
:
# https://stackoverflow.com/questions/1350466/preventing-python-code-from-importing-certain-modules
sys
.
modules
[
'torch'
]
=
None
import_related
(
'tensorflow'
)
if
args
.
masked
==
'torch_none'
:
sys
.
modules
[
'torch'
]
=
None
import_related
(
'none'
)
elif
args
.
masked
==
'tensorflow'
:
sys
.
modules
[
'tensorflow'
]
=
None
import_related
(
'pytorch'
)
@
pytest
.
mark
.
parametrize
(
'framework'
,
[
'torch'
,
'torch_none'
,
'tensorflow'
])
def
test_import_without_framework
(
framework
):
subprocess
.
run
([
sys
.
executable
,
__file__
,
framework
],
check
=
True
)
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment