"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "db3130d7059bbeb2341857120f6a3c4690fbd2c5"
Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .execution import *
from .fixed import fixed_arch
from .mutable import *
from .utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .evaluator import Evaluator
from .functional import FunctionalEvaluator
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['Evaluator']
import abc
from typing import Any, Callable, Type, Union, cast
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional evaluator might directly import the function and call the function.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import nni import nni
from nni.retiarii.graph import Evaluator from .evaluator import Evaluator
@nni.trace @nni.trace
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .lightning import *
...@@ -15,6 +15,12 @@ import nni ...@@ -15,6 +15,12 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer from .trainer import Trainer
__all__ = [
'_MultiModelSupervisedLearningModule', 'MultiModelSupervisedLearningModule',
'_ClassificationModule', 'Classification',
'_RegressionModule', 'Regression',
]
@nni.trace @nni.trace
class _MultiModelSupervisedLearningModule(LightningModule): class _MultiModelSupervisedLearningModule(LightningModule):
......
...@@ -21,11 +21,15 @@ try: ...@@ -21,11 +21,15 @@ try:
except ImportError: except ImportError:
cgo_import_failed = True cgo_import_failed = True
from nni.retiarii.graph import Evaluator from nni.nas.evaluator import Evaluator
from nni.typehint import Literal from nni.typehint import Literal
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression'] __all__ = [
'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
'_AccuracyWithLogits', '_SupervisedLearningModule', '_ClassificationModule', '_RegressionModule',
# FIXME: hack to make it importable for tests
]
class LightningModule(pl.LightningModule): class LightningModule(pl.LightningModule):
...@@ -113,7 +117,7 @@ class Lightning(Evaluator): ...@@ -113,7 +117,7 @@ class Lightning(Evaluator):
else: else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \ assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer' f'Trainer must be imported from {__name__} or nni.nas.evaluator.pytorch.cgo.trainer'
if not _check_dataloader(train_dataloaders): if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or ' warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}', f'import DataLoader from {__name__}: {train_dataloaders}',
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import *
from .common import *
...@@ -5,9 +5,11 @@ import time ...@@ -5,9 +5,11 @@ import time
import warnings import warnings
from typing import Iterable from typing import Iterable
from ..graph import Model, ModelStatus from nni.nas.execution.common import (
from .interface import AbstractExecutionEngine Model, ModelStatus,
from .listener import DefaultListener AbstractExecutionEngine,
DefaultListener
)
_execution_engine = None _execution_engine = None
_default_listener = None _default_listener = None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
from .graph_op import *
from .graph import *
from .integration_api import *
from .integration import *
from .listener import *
from .utils import *
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
from abc import ABC, abstractmethod, abstractclassmethod from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type from typing import Any, Iterable, NewType, List, Union, Type
from ..graph import Model, MetricData from .graph import Model, MetricData
__all__ = [ __all__ = [
'GraphData', 'WorkerInfo', 'GraphData', 'WorkerInfo', 'MetricData',
'AbstractGraphListener', 'AbstractExecutionEngine' 'AbstractGraphListener', 'AbstractExecutionEngine'
] ]
......
...@@ -2,24 +2,27 @@ ...@@ -2,24 +2,27 @@
# Licensed under the MIT license. # Licensed under the MIT license.
""" """
Model representation. Model representation for engines based on graph.
""" """
from __future__ import annotations from __future__ import annotations
import abc
import json import json
from enum import Enum from enum import Enum
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List, from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload) Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING: if TYPE_CHECKING:
from .mutator import Mutator from .mutator import Mutator
from .operation import Cell, Operation, _IOPseudoOperation from nni.nas.evaluator import Evaluator
from .utils import uid from nni.nas.utils import uid
from .graph_op import Cell, Operation, _IOPseudoOperation
__all__ = ['Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData'] __all__ = [
'Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData',
'DebugEvaluator',
]
MetricData = Any MetricData = Any
...@@ -33,57 +36,6 @@ Type hint for edge's endpoint. The int indicates nodes' order. ...@@ -33,57 +36,6 @@ Type hint for edge's endpoint. The int indicates nodes' order.
""" """
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional evaluator might directly import the function and call the function.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
class Model: class Model:
""" """
Represents a neural network model. Represents a neural network model.
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
"""
Operations used in graph-based engine.
"""
from typing import (Any, Dict, List, Optional, cast) from typing import (Any, Dict, List, Optional, cast)
from . import debug_configs from nni.common.framework import get_default_framework
__all__ = ['Operation', 'Cell'] __all__ = ['Operation', 'Cell', 'PyTorchOperation', 'TensorFlowOperation']
def _convert_name(name: str) -> str: def _convert_name(name: str) -> str:
...@@ -63,14 +68,14 @@ class Operation: ...@@ -63,14 +68,14 @@ class Operation:
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters) return Cell(cell_name, parameters)
else: else:
if debug_configs.framework.lower() in ('torch', 'pytorch'): if get_default_framework() in ('torch', 'pytorch'):
from .operation_def import torch_op_def # pylint: disable=unused-import from nni.nas.execution.pytorch import op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type_name) cls = PyTorchOperation._find_subclass(type_name)
elif debug_configs.framework.lower() in ('tf', 'tensorflow'): elif get_default_framework() in ('tf', 'tensorflow'):
from .operation_def import tf_op_def # pylint: disable=unused-import from nni.nas.execution.tensorflow import op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type_name) cls = TensorFlowOperation._find_subclass(type_name)
else: else:
raise ValueError(f'Unsupported framework: {debug_configs.framework}') raise ValueError(f'Unsupported framework: {get_default_framework()}')
return cls(type_name, parameters, _internal=True, attributes=attributes) return cls(type_name, parameters, _internal=True, attributes=attributes)
@classmethod @classmethod
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
__all__ = ['RetiariiAdvisor']
import logging import logging
import os import os
from typing import Any, Callable, Optional, Dict, List, Tuple from typing import Any, Callable, Optional, Dict, List, Tuple
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
__all__ = [
'get_advisor', 'register_advisor', 'send_trial', 'receive_trial_parameters', 'get_experiment_id',
'_advisor' # FIXME: hack to make it importable for tests
]
import warnings import warnings
from typing import NewType, Any from typing import NewType, Any
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from ..graph import Model, ModelStatus __all__ = ['DefaultListener']
from .interface import MetricData, AbstractGraphListener
from .graph import Model, ModelStatus, MetricData
from .engine import AbstractGraphListener
class DefaultListener(AbstractGraphListener): class DefaultListener(AbstractGraphListener):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
__all__ = ['unpack_if_only_one', 'get_mutation_dict', 'mutation_dict_to_summary', 'get_mutation_summary']
from typing import Any, List from typing import Any, List
from ..graph import Model from .graph import Model
def _unpack_if_only_one(ele: List[Any]): def unpack_if_only_one(ele: List[Any]):
if len(ele) == 1: if len(ele) == 1:
return ele[0] return ele[0]
return ele return ele
def get_mutation_dict(model: Model): def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history} return {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
def mutation_dict_to_summary(mutation: dict) -> dict: def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {} mutation_summary = {}
...@@ -23,6 +28,7 @@ def mutation_dict_to_summary(mutation: dict) -> dict: ...@@ -23,6 +28,7 @@ def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary[f'{label}_{i}'] = sample mutation_summary[f'{label}_{i}'] = sample
return mutation_summary return mutation_summary
def get_mutation_summary(model: Model) -> dict: def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model) mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation) return mutation_dict_to_summary(mutation)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment