Unverified Commit baf60758 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Prepare for multi-framework support in NAS (#4976)

parent 0e835aa9
Uncategorized Modules
=====================
nni.common.framework
--------------------
.. automodule:: nni.common.framework
:members:
nni.common.serializer
---------------------
......
......@@ -9,6 +9,7 @@ except ModuleNotFoundError:
from .runtime.log import _init_logger
_init_logger()
from .common.framework import *
from .common.serializer import trace, dump, load
from .experiment import Experiment
from .runtime.env_vars import dispatcher_env_vars
......
......@@ -15,7 +15,6 @@ import tqdm
__all__ = ['NNI_BLOB', 'load_or_download_file', 'upload_file', 'nni_cache_home']
# Blob that contains some downloadable files.
NNI_BLOB = 'https://nni.blob.core.windows.net'
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['set_default_framework', 'get_default_framework', 'shortcut_module', 'shortcut_framework']
import importlib
import os
import sys
from typing import Optional, cast
from typing_extensions import Literal
framework_type = Literal['pytorch', 'tensorflow', 'mxnet', 'none']
"""Supported framework types."""
ENV_NNI_FRAMEWORK = 'NNI_FRAMEWORK'
def framework_from_env() -> framework_type:
framework = os.getenv(ENV_NNI_FRAMEWORK, 'pytorch')
if framework not in framework_type.__args__: # type: ignore
raise ValueError(f'{framework} does not belong to {framework_type.__args__}') # type: ignore
return cast(framework_type, framework)
DEFAULT_FRAMEWORK = framework_from_env()
def set_default_framework(framework: framework_type) -> None:
"""Set default deep learning framework to simplify imports.
Some functionalities in NNI (e.g., NAS / Compression), relies on an underlying DL framework.
For different DL frameworks, the implementation of NNI can be very different.
Thus, users need import things tailored for their own framework. For example: ::
from nni.nas.xxx.pytorch import yyy
rather than: ::
from nni.nas.xxx import yyy
By setting a default framework, shortcuts will be made. As such ``nni.nas.xxx`` will be equivalent to ``nni.nas.xxx.pytorch``.
Another way to setting it is through environment variable ``NNI_FRAMEWORK``,
which needs to be set before the whole process starts.
If you set the framework with :func:`set_default_framework`,
it should be done before all imports (except nni itself) happen,
because it will affect other import's behaviors.
And the behavior is undefined if the framework is "re"-set in the middle.
The supported frameworks here are listed below.
It doesn't mean that they are fully supported by NAS / Compression in NNI.
* ``pytorch`` (default)
* ``tensorflow``
* ``mxnet``
* ``none`` (to disable the shortcut-import behavior).
Examples
--------
>>> import nni
>>> nni.set_default_framework('tensorflow')
>>> # then other imports
>>> from nni.nas.xxx import yyy
"""
# In case 'none' is written as None.
if framework is None:
framework = 'none'
global DEFAULT_FRAMEWORK
DEFAULT_FRAMEWORK = framework
def get_default_framework() -> framework_type:
"""Retrieve default deep learning framework set either with env variables or manually."""
return DEFAULT_FRAMEWORK
def shortcut_module(current: str, target: str, package: Optional[str] = None) -> None:
"""Make ``current`` module an alias of ``target`` module in ``package``."""
# Reference: https://github.com/dmlc/dgl/blob/d70a362dba8d46fd9838c79d76998a5e33f22cb7/python/dgl/nn/__init__.py#L27
mod = importlib.import_module(target, package)
thismod = sys.modules[current]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
def shortcut_framework(current: str) -> None:
"""Make ``current`` a shortcut of ``current.framework``."""
if get_default_framework() != 'none':
# Throw ModuleNotFoundError if framework is not supported
shortcut_module(current, '.' + get_default_framework(), current)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .pytorch import model_to_pytorch_script
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['model_to_pytorch_script']
import logging
import re
from typing import Dict, List, Tuple, Any, cast
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .functional import FunctionalEvaluator
shortcut_framework(__name__)
del shortcut_framework
......@@ -146,7 +146,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
return BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod
def trial_execute_graph(cls) -> None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
\ No newline at end of file
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
......@@ -20,7 +20,7 @@ from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
from ..codegen import model_to_pytorch_script
from ..codegen.pytorch import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
......@@ -97,7 +97,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from ..strategy import _LocalDebugStrategy
from ..strategy.local_debug_strategy import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import tensorflow as tf
class LayerChoice(tf.keras.Layer):
# FIXME: This is only a draft to test multi-framework support, it's not unimplemented at all.
pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .interface import BaseOneShotTrainer
shortcut_framework(__name__)
del shortcut_framework
......@@ -5,6 +5,5 @@ from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy, TPE
from .local_debug_strategy import _LocalDebugStrategy
from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot
......@@ -24,7 +24,7 @@ class _LocalDebugStrategy(BaseStrategy):
def run_one_model(self, model):
mutation_summary = get_mutation_summary(model)
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
graph_data = BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
......
"""To test the cases of importing NAS without certain DL libraries installed."""
import argparse
import subprocess
import sys
import pytest
def import_related(mask_out):
import nni
nni.set_default_framework(mask_out)
import nni.retiarii
import nni.retiarii.evaluator
import nni.retiarii.hub
import nni.retiarii.strategy # FIXME: this doesn't work yet
import nni.retiarii.experiment
def main():
parser = argparse.ArgumentParser()
parser.add_argument('masked', choices=['torch', 'torch_none', 'tensorflow'])
args = parser.parse_args()
if args.masked == 'torch':
# https://stackoverflow.com/questions/1350466/preventing-python-code-from-importing-certain-modules
sys.modules['torch'] = None
import_related('tensorflow')
if args.masked == 'torch_none':
sys.modules['torch'] = None
import_related('none')
elif args.masked == 'tensorflow':
sys.modules['tensorflow'] = None
import_related('pytorch')
@pytest.mark.parametrize('framework', ['torch', 'torch_none', 'tensorflow'])
def test_import_without_framework(framework):
subprocess.run([sys.executable, __file__, framework], check=True)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment