Unverified commit bc6d8796 authored by Yuge Zhang, committed by GitHub

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .execution import *
from .fixed import fixed_arch
from .mutable import *
from .utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .evaluator import Evaluator
from .functional import FunctionalEvaluator
shortcut_framework(__name__)
del shortcut_framework
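For context, a minimal usage sketch of the ``FunctionalEvaluator`` re-exported above; the training function ``fit``, its placeholder body, and the reported metric are illustrative only and not taken from the commit:

import nni
from nni.nas.evaluator import FunctionalEvaluator

def fit(model_cls):
    # The evaluator calls this with the mutated model class (see the Evaluator
    # docstring below): instantiate it, train and validate it, then report the metric.
    model = model_cls()
    accuracy = 0.0  # placeholder for a real training and evaluation loop
    nni.report_final_result(accuracy)

evaluator = FunctionalEvaluator(fit)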
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['Evaluator']
import abc
from typing import Any, Callable, Type, Union, cast
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is and how the training code is configured.
The configuration includes basic runtime information the trainer needs to know (such as the number of GPUs)
and tunable parameters (such as the learning rate), depending on the implementation of the training code.
Each evaluator should define how its configuration is interpreted in ``_execute()``, which takes a single argument: the mutated model class.
For example, a functional evaluator might directly import the training function and call it.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclasses implement ``_dump`` for their own serialization.
They should return a dict with a ``type`` key equal to ``self.__class__``,
plus any other keys they need.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
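As a hedged illustration of the ``_dump``/``_load``/``_execute`` contract above, a hypothetical subclass could look like the following; the class name and its single ``learning_rate`` field are invented for the example:

class MyEvaluator(Evaluator):
    """Hypothetical evaluator carrying one tunable parameter."""

    def __init__(self, learning_rate: float = 1e-3):
        self.learning_rate = learning_rate

    def _dump(self) -> Any:
        # 'type' must be the class itself so that Evaluator._load can dispatch back here.
        return {'type': self.__class__, 'learning_rate': self.learning_rate}

    @staticmethod
    def _load(ir: Any) -> 'MyEvaluator':
        return MyEvaluator(learning_rate=ir['learning_rate'])

    def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
        model = model_cls()  # instantiate the mutated model
        # a real subclass would train and evaluate `model` here and return its metric
        return 0.0

    def __eq__(self, other) -> bool:
        return isinstance(other, MyEvaluator) and self.learning_rate == other.learning_rate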
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import nni
from nni.retiarii.graph import Evaluator
from .evaluator import Evaluator
@nni.trace
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .lightning import *
......@@ -15,6 +15,12 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
__all__ = [
'_MultiModelSupervisedLearningModule', 'MultiModelSupervisedLearningModule',
'_ClassificationModule', 'Classification',
'_RegressionModule', 'Regression',
]
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
......
......@@ -21,11 +21,15 @@ try:
except ImportError:
cgo_import_failed = True
from nni.retiarii.graph import Evaluator
from nni.nas.evaluator import Evaluator
from nni.typehint import Literal
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
__all__ = [
'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
'_AccuracyWithLogits', '_SupervisedLearningModule', '_ClassificationModule', '_RegressionModule',
# FIXME: hack to make it importable for tests
]
class LightningModule(pl.LightningModule):
......@@ -113,7 +117,7 @@ class Lightning(Evaluator):
else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
f'Trainer must be imported from {__name__} or nni.nas.evaluator.pytorch.cgo.trainer'
if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}',
......
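A hedged usage sketch of the convention enforced by the assertion and warning above: import ``DataLoader`` (and the rest of the evaluator) from ``nni.nas.evaluator.pytorch`` so the objects are traced. The MNIST dataset, batch size, and epoch count are arbitrary, and the ``train_dataloaders``/``val_dataloaders`` keyword spelling follows the hunk above:

import nni
from torchvision import datasets, transforms
from nni.nas.evaluator.pytorch import Classification, DataLoader  # traced DataLoader, as the warning asks

transform = transforms.ToTensor()
# Wrapping the dataset constructor with nni.trace keeps its arguments serializable.
train_data = nni.trace(datasets.MNIST)('data/mnist', train=True, download=True, transform=transform)
valid_data = nni.trace(datasets.MNIST)('data/mnist', train=False, download=True, transform=transform)

evaluator = Classification(
    train_dataloaders=DataLoader(train_data, batch_size=64),
    val_dataloaders=DataLoader(valid_data, batch_size=64),
    max_epochs=3,  # extra keyword arguments are forwarded to the traced Trainer
)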
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import *
from .common import *
......@@ -5,9 +5,11 @@ import time
import warnings
from typing import Iterable
from ..graph import Model, ModelStatus
from .interface import AbstractExecutionEngine
from .listener import DefaultListener
from nni.nas.execution.common import (
Model, ModelStatus,
AbstractExecutionEngine,
DefaultListener
)
_execution_engine = None
_default_listener = None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
from .graph_op import *
from .graph import *
from .integration_api import *
from .integration import *
from .listener import *
from .utils import *
......@@ -4,10 +4,10 @@
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type
from ..graph import Model, MetricData
from .graph import Model, MetricData
__all__ = [
'GraphData', 'WorkerInfo',
'GraphData', 'WorkerInfo', 'MetricData',
'AbstractGraphListener', 'AbstractExecutionEngine'
]
......
......@@ -2,24 +2,27 @@
# Licensed under the MIT license.
"""
Model representation.
Model representation for engines based on graph.
"""
from __future__ import annotations
import abc
import json
from enum import Enum
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING:
from .mutator import Mutator
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
from nni.nas.evaluator import Evaluator
from nni.nas.utils import uid
from .graph_op import Cell, Operation, _IOPseudoOperation
__all__ = ['Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
__all__ = [
'Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData',
'DebugEvaluator',
]
MetricData = Any
......@@ -33,57 +36,6 @@ Type hint for edge's endpoint. The int indicates nodes' order.
"""
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional evaluator might directly import the function and call the function.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
class Model:
"""
Represents a neural network model.
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations used in graph-based engine.
"""
from typing import (Any, Dict, List, Optional, cast)
from . import debug_configs
from nni.common.framework import get_default_framework
__all__ = ['Operation', 'Cell']
__all__ = ['Operation', 'Cell', 'PyTorchOperation', 'TensorFlowOperation']
def _convert_name(name: str) -> str:
......@@ -63,14 +68,14 @@ class Operation:
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
else:
if debug_configs.framework.lower() in ('torch', 'pytorch'):
from .operation_def import torch_op_def # pylint: disable=unused-import
if get_default_framework() in ('torch', 'pytorch'):
from nni.nas.execution.pytorch import op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type_name)
elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
from .operation_def import tf_op_def # pylint: disable=unused-import
elif get_default_framework() in ('tf', 'tensorflow'):
from nni.nas.execution.tensorflow import op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type_name)
else:
raise ValueError(f'Unsupported framework: {debug_configs.framework}')
raise ValueError(f'Unsupported framework: {get_default_framework()}')
return cls(type_name, parameters, _internal=True, attributes=attributes)
@classmethod
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['RetiariiAdvisor']
import logging
import os
from typing import Any, Callable, Optional, Dict, List, Tuple
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = [
'get_advisor', 'register_advisor', 'send_trial', 'receive_trial_parameters', 'get_experiment_id',
'_advisor' # FIXME: hack to make it importable for tests
]
import warnings
from typing import NewType, Any
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..graph import Model, ModelStatus
from .interface import MetricData, AbstractGraphListener
__all__ = ['DefaultListener']
from .graph import Model, ModelStatus, MetricData
from .engine import AbstractGraphListener
class DefaultListener(AbstractGraphListener):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['unpack_if_only_one', 'get_mutation_dict', 'mutation_dict_to_summary', 'get_mutation_summary']
from typing import Any, List
from ..graph import Model
from .graph import Model
def _unpack_if_only_one(ele: List[Any]):
def unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
return {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {}
......@@ -23,6 +28,7 @@ def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary[f'{label}_{i}'] = sample
return mutation_summary
def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation)
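For a sense of what the renamed helpers produce, a hedged sketch follows; the mutator labels and samples are invented, and the import path assumes these utilities are re-exported through the ``common`` package ``__init__`` shown earlier:

from nni.nas.execution.common import mutation_dict_to_summary

# A mutation dict shaped the way get_mutation_dict builds it from model.history:
# one entry per mutator label, single samples unpacked, multi-sample lists kept.
mutation = {
    'conv_choice': 'conv3x3',           # mutator sampled once
    'cell_ops': ['sep_conv', 'skip'],   # mutator sampled twice
}
summary = mutation_dict_to_summary(mutation)
# List-valued entries are flattened into indexed keys, roughly:
# {'conv_choice': 'conv3x3', 'cell_ops_0': 'sep_conv', 'cell_ops_1': 'skip'}
print(summary)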