Unverified commit a911b856 authored by Yuge Zhang, committed by GitHub
Browse files

Resolve conflicts for #4760 (#4762)

parent 14d2966b
......@@ -47,6 +47,7 @@ class _AlgorithmConfig(ConfigBase):
else: # custom algorithm
assert self.name is None
assert self.class_name
assert self.code_directory is not None
if not Path(self.code_directory).is_dir():
raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory')
......
......@@ -37,6 +37,8 @@ def to_v2(v1):
_move_field(v1_trial, v2, 'command', 'trialCommand')
_move_field(v1_trial, v2, 'codeDir', 'trialCodeDirectory')
_move_field(v1_trial, v2, 'gpuNum', 'trialGpuNumber')
else:
v1_trial = {}
for algo_type in ['tuner', 'assessor', 'advisor']:
v1_algo = v1.pop(algo_type, None)
......
......@@ -46,6 +46,7 @@ class FrameworkControllerConfig(TrainingServiceConfig):
service_account_name: Optional[str]
task_roles: List[FrameworkControllerRoleConfig]
reuse_mode: Optional[bool] = True
namespace: str = 'default'
def _canonicalize(self, parents):
super()._canonicalize(parents)
......
......@@ -43,6 +43,7 @@ class KubeflowConfig(TrainingServiceConfig):
ps: Optional[KubeflowRoleConfig] = None
master: Optional[KubeflowRoleConfig] = None
reuse_mode: Optional[bool] = True #set reuse mode as true for v2 config
namespace: str = 'default'
def _canonicalize(self, parents):
super()._canonicalize(parents)
......
......@@ -53,7 +53,7 @@ class RemoteMachineConfig(ConfigBase):
if self.password is not None:
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
elif not Path(self.ssh_key_file).is_file():
elif not Path(self.ssh_key_file).is_file(): # type: ignore
raise ValueError(
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
)
......
......@@ -20,6 +20,15 @@ import nni.runtime.config
from .public import is_missing
__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
## handle relative path ##
_current_base_path = None
......
......@@ -10,6 +10,12 @@ import math
from pathlib import Path
from typing import Union
__all__ = [
'PathLike', 'is_missing',
'canonical_gpu_indices', 'validate_gpu_indices',
'parse_time', 'parse_memory_size'
]
PathLike = Union[Path, str]
def is_missing(value):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
import json
from typing import List
......@@ -111,6 +114,8 @@ class TrialJob:
Stderr log path.
sequenceId: int
Sequence Id.
message: str
Message including platform/environment.
"""
trialJobId: str
status: str
......@@ -121,9 +126,11 @@ class TrialJob:
finalMetricData: List[TrialMetricData]
stderrPath: str
sequenceId: int
message: str
def __init__(self, trialJobId: str, status: str, logPath: str, startTime: int, sequenceId: int,
endTime: int = -1, stderrPath: str = '', hyperParameters: List = [], finalMetricData: List = []):
def __init__(self, trialJobId: str, status: str, startTime: int, sequenceId: int, logPath: str = '',
endTime: int = -1, stderrPath: str = '', hyperParameters: List = [], finalMetricData: List = [],
message: str = '--'):
self.trialJobId = trialJobId
self.status = status
self.hyperParameters = [TrialHyperParameters(**json.loads(e)) for e in hyperParameters]
......@@ -133,3 +140,4 @@ class TrialJob:
self.finalMetricData = [TrialMetricData(**e) for e in finalMetricData]
self.stderrPath = stderrPath
self.sequenceId = sequenceId
self.message = message
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import atexit
from enum import Enum
import logging
......@@ -5,7 +10,7 @@ from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any
from typing import Any
import colorama
import psutil
......@@ -25,80 +30,61 @@ class RunMode(Enum):
"""
Config lifecycle and output redirection of NNI manager process.
- Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
- Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
- Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
- Detach: do not stop NNI manager when Python script exits.
NOTE:
This API is non-stable and is likely to get refactored in next release.
NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
NOTE: This API is non-stable and is likely to get refactored in upcoming release.
"""
# TODO:
# NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
Background = 'background'
Foreground = 'foreground'
Detach = 'detach'
class Experiment:
"""
Create and stop an NNI experiment.
Manage NNI experiment.
You can either specify an :class:`ExperimentConfig` object, or a training service name.
If a platform name is used, a blank config template for that training service will be generated.
When configuration is completed, use :meth:`Experiment.run` to launch the experiment.
Example
-------
.. code-block::
experiment = Experiment('remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.run(8080)
Attributes
----------
config
Experiment configuration.
id
Experiment ID.
port
Web UI port of the experiment, or `None` if it is not running.
Web portal port. Or ``None`` if the experiment is not running.
"""
@overload
def __init__(self, config: ExperimentConfig) -> None:
"""
Prepare an experiment.
Use `Experiment.run()` to launch it.
Parameters
----------
config
Experiment configuration.
"""
...
@overload
def __init__(self, training_service: Union[str, List[str]]) -> None:
"""
Prepare an experiment, leaving configuration fields to be set later.
Example usage::
experiment = Experiment('remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.run(8080)
Parameters
----------
training_service
Name of training service.
Supported value: "local", "remote", "openpai", "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service.
"""
...
def __init__(self, config=None, training_service=None):
def __init__(self, config_or_platform: ExperimentConfig | str | list[str] | None):
nni.runtime.log.init_logger_for_command_line()
self.config: Optional[ExperimentConfig] = None
self.config: ExperimentConfig | None = None
self.id: str = management.generate_experiment_id()
self.port: Optional[int] = None
self._proc: Optional[Popen] = None
self.mode = 'new'
self.url_prefix: Optional[str] = None
args = [config, training_service] # deal with overloading
if isinstance(args[0], (str, list)):
self.config = ExperimentConfig(args[0])
self.port: int | None = None
self._proc: Popen | psutil.Process | None = None
self._action = 'create'
self.url_prefix: str | None = None
if isinstance(config_or_platform, (str, list)):
self.config = ExperimentConfig(config_or_platform)
else:
self.config = args[0]
self.config = config_or_platform
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
......@@ -114,6 +100,7 @@ class Experiment:
debug
Whether to start in debug mode.
"""
assert self.config is not None
if run_mode is not RunMode.Detach:
atexit.register(self.stop)
......@@ -127,7 +114,7 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc = launcher.start_experiment(self.mode, self.id, config, port, debug, run_mode, self.url_prefix)
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
assert self._proc is not None
self.port = port # port will be None if start up failed
......@@ -138,12 +125,12 @@ class Experiment:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
msg = 'Web portal URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg)
def stop(self) -> None:
"""
Stop background experiment.
Stop the experiment.
"""
_logger.info('Stopping experiment, please wait...')
atexit.unregister(self.stop)
......@@ -157,20 +144,20 @@ class Experiment:
_logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid)
self.id = None
self.id = None # type: ignore
self.port = None
self._proc = None
_logger.info('Experiment stopped')
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool:
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
"""
Run the experiment.
If wait_completion is True, this function will block until experiment finish or error.
If ``wait_completion`` is ``True``, this function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
Return ``True`` when experiment done; or return ``False`` when experiment failed.
Else if wait_completion is False, this function will non-block and return None immediately.
Else if ``wait_completion`` is ``False``, this function will non-block and return None immediately.
"""
self.start(port, debug)
if wait_completion:
......@@ -184,7 +171,6 @@ class Experiment:
return False
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
@classmethod
......@@ -197,7 +183,7 @@ class Experiment:
port
The port of web UI.
"""
experiment = Experiment()
experiment = Experiment(None)
experiment.port = port
experiment.id = experiment.get_experiment_profile().get('id')
status = experiment.get_status()
......@@ -259,17 +245,17 @@ class Experiment:
@staticmethod
def _resume(exp_id, exp_dir=None):
exp = Experiment()
exp = Experiment(None)
exp.id = exp_id
exp.mode = 'resume'
exp._action = 'resume'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp
@staticmethod
def _view(exp_id, exp_dir=None):
exp = Experiment()
exp = Experiment(None)
exp.id = exp_id
exp.mode = 'view'
exp._action = 'view'
exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
return exp
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import contextlib
from dataclasses import dataclass, fields
from datetime import datetime
......@@ -27,23 +29,27 @@ _logger = logging.getLogger('nni.experiment')
@dataclass(init=False)
class NniManagerArgs:
# argv sent to "ts/nni_manager/main.js"
port: int
experiment_id: int
start_mode: str # new or resume
mode: str # training service platform
log_dir: str
action: str # 'new', 'resume', 'view'
mode: str # training service platform, to be removed
experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute
log_level: str
readonly: bool = False
foreground: bool = False
url_prefix: Optional[str] = None
url_prefix: Optional[str] = None # leading and trailing "/" must be stripped
dispatcher_pipe: Optional[str] = None
def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix):
self.port = port
self.experiment_id = exp_id
self.action = action
self.foreground = foreground
self.url_prefix = url_prefix
self.log_dir = config.experiment_working_directory
# config field name "experiment_working_directory" is a mistake
# see "ts/nni_manager/common/globals/arguments.ts" for details
self.experiments_directory = config.experiment_working_directory
if isinstance(config.training_service, list):
self.mode = 'hybrid'
......@@ -54,20 +60,14 @@ class NniManagerArgs:
if debug and self.log_level not in ['debug', 'trace']:
self.log_level = 'debug'
if action == 'resume':
self.start_mode = 'resume'
elif action == 'view':
self.start_mode = 'resume'
self.readonly = True
else:
self.start_mode = 'new'
def to_command_line_args(self):
# reformat fields to meet yargs library's format
# see "ts/nni_manager/common/globals/arguments.ts" for details
ret = []
for field in fields(self):
value = getattr(self, field.name)
if value is not None:
ret.append('--' + field.name)
ret.append('--' + field.name.replace('_', '-'))
if isinstance(value, bool):
ret.append(str(value).lower())
else:
......@@ -76,6 +76,8 @@ class NniManagerArgs:
def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
foreground = run_mode.value == 'foreground'
if url_prefix is not None:
url_prefix = url_prefix.strip('/')
nni_manager_args = NniManagerArgs(action, exp_id, config, port, debug, foreground, url_prefix)
_ensure_port_idle(port)
......@@ -118,7 +120,11 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
link = Path(config.experiment_working_directory, '_latest')
try:
link.unlink(missing_ok=True)
if sys.version_info >= (3, 8):
link.unlink(missing_ok=True)
else:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
......@@ -126,16 +132,16 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return proc
def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
def _start_rest_server(nni_manager_args, run_mode) -> Popen:
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
cmd += nni_manager_args.to_command_line_args()
if run_mode.value == 'detach':
log = Path(nni_manager_args.log_dir, nni_manager_args.experiment_id, 'log')
log = Path(nni_manager_args.experiments_directory, nni_manager_args.experiment_id, 'log')
out = (log / 'nnictl_stdout.log').open('a')
err = (log / 'nnictl_stderr.log').open('a')
header = f'Experiment {nni_manager_args.experiment_id} start: {datetime.now()}'
......@@ -151,10 +157,10 @@ def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
from subprocess import CREATE_NEW_PROCESS_GROUP
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore
def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
def start_experiment_retiarii(exp_id, config, port, debug):
pipe = None
proc = None
......@@ -201,7 +207,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool, experiment_id: str,
pipe_path: str = None, mode: str = 'new') -> Tuple[int, Popen]:
pipe_path: str, mode: str = 'create') -> Tuple[int, Popen]:
if isinstance(config.training_service, list):
ts = 'hybrid'
else:
......@@ -213,24 +219,20 @@ def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool
'port': port,
'mode': ts,
'experiment_id': experiment_id,
'start_mode': mode,
'log_dir': config.experiment_working_directory,
'action': mode,
'experiments_directory': config.experiment_working_directory,
'log_level': 'debug' if debug else 'info'
}
if pipe_path is not None:
args['dispatcher_pipe'] = pipe_path
if mode == 'view':
args['start_mode'] = 'resume'
args['readonly'] = 'true'
import nni_node
node_dir = Path(nni_node.__path__[0])
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
for arg_key, arg_value in args.items():
cmd.append('--' + arg_key)
cmd.append('--' + arg_key.replace('_', '-'))
cmd.append(str(arg_value))
if sys.platform == 'win32':
......@@ -263,8 +265,8 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def get_stopped_experiment_config(exp_id, exp_dir=None):
config_json = get_stopped_experiment_config_json(exp_id, exp_dir)
config = ExperimentConfig(**config_json)
config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
config = ExperimentConfig(**config_json) # type: ignore
if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger.warning(msg, exp_dir, config.experiment_working_directory)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import random
import string
......
from io import BufferedIOBase
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys
......@@ -25,7 +27,7 @@ if sys.platform == 'win32':
_winapi.NULL
)
def connect(self) -> BufferedIOBase:
def connect(self):
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b')
......@@ -55,7 +57,7 @@ else:
self._socket.bind(self.path)
self._socket.listen(1) # only accepts one connection
def connect(self) -> BufferedIOBase:
def connect(self):
conn, _ = self._socket.accept()
self.file = conn.makefile('rwb')
return self.file
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Any, Optional
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import os
class Recoverable:
def load_checkpoint(self):
def load_checkpoint(self) -> None:
pass
def save_checkpoint(self):
def save_checkpoint(self) -> None:
pass
def get_checkpoint_path(self):
def get_checkpoint_path(self) -> str | None:
ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path
......
......@@ -6,6 +6,7 @@ import re
from typing import Dict, List, Tuple, Any
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
......@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return re.sub('\W|^(?=\d)','_', name)
name = re.sub('\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
......@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
import_pkgs = set()
node_codes = []
node_python_mappings = {}
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
......@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
......@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else:
node_codes.append(node_code)
# Map to module hierarchies in original search space python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
......
......@@ -660,8 +660,9 @@ class GraphConverter:
attrs = {
'mutation': 'repeat',
'label': module.label,
'depth': module.depth_choice,
'max_depth': module.max_depth,
'min_depth': module.min_depth,
'max_depth': module.max_depth
}
return ir_graph, attrs
......@@ -695,15 +696,17 @@ class GraphConverter:
class GraphConverterWithShape(GraphConverter):
"""
Convert a pytorch model to nni ir along with input/output shape info.
Based ir acquired through `torch.jit.script`
and shape info acquired through `torch.jit.trace`.
Known issues
------------
1. `InputChoice` and `ValueChoice` not supported yet.
2. Currently random inputs are fed while tracing layerchoice.
If forward path of candidates depends on input data, then wrong path will be traced.
This will result in incomplete shape info.
The base IR is acquired through ``torch.jit.script``,
and shape info is acquired through ``torch.jit.trace``.
.. warning::
Known issues:
1. ``InputChoice`` and ``ValueChoice`` not supported yet.
2. Currently random inputs are fed while tracing layerchoice.
If forward path of candidates depends on input data, then wrong path will be traced.
This will result in incomplete shape info.
"""
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
......@@ -815,7 +818,7 @@ class GraphConverterWithShape(GraphConverter):
def convert_to_graph(script_module, module, converter=None, **kwargs):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
Convert module to our graph ir, i.e., build a :class:`Model` type
Parameters
----------
......
......@@ -29,8 +29,20 @@ class _MultiModelSupervisedLearningModule(LightningModule):
self.criterion_cls = criterion
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
self.metrics_args = metrics
self.n_models = n_models
def dump_kwargs(self):
kwargs = {}
kwargs['criterion'] = self.criterion_cls
kwargs['metrics'] = self.metrics_args
kwargs['n_models'] = self.n_models
kwargs['learning_rate'] = self.hparams['learning_rate']
kwargs['weight_decay'] = self.hparams['weight_decay']
kwargs['optimizer'] = self.optimizer
return kwargs
def forward(self, x):
y_hat = self.model(x)
return y_hat
......@@ -125,8 +137,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@nni.trace
class _ClassificationModule(MultiModelSupervisedLearningModule):
class _ClassificationModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
......@@ -157,7 +168,7 @@ class Classification(Lightning):
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
......@@ -172,9 +183,7 @@ class Classification(Lightning):
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@nni.trace
class _RegressionModule(MultiModelSupervisedLearningModule):
class _RegressionModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
......@@ -205,7 +214,7 @@ class Regression(Lightning):
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
......
......@@ -2,11 +2,8 @@
# Licensed under the MIT license.
import pytorch_lightning as pl
import nni
from .accelerator import BypassAccelerator
@nni.trace
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
......@@ -20,7 +17,7 @@ class Trainer(pl.Trainer):
default: False
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, use_cgo=False, **trainer_kwargs):
......
......@@ -4,7 +4,7 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Type
from typing import Dict, Union, Optional, List, Callable
import pytorch_lightning as pl
import torch.nn as nn
......@@ -29,11 +29,20 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
class LightningModule(pl.LightningModule):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
It's a subclass of ``pytorch_lightning.LightningModule``.
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None:
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
Parameters
----------
model : callable or nn.Module
Can be a callable returning nn.Module or nn.Module.
"""
if isinstance(model, nn.Module):
self.model = model
else:
......@@ -41,7 +50,13 @@ class LightningModule(pl.LightningModule):
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""
DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace
class Lightning(Evaluator):
......@@ -162,7 +177,6 @@ class _SupervisedLearningModule(LightningModule):
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
self._already_exported = False
def forward(self, x):
y_hat = self.model(x)
......@@ -181,12 +195,12 @@ class _SupervisedLearningModule(LightningModule):
x, y = batch
y_hat = self(x)
if not self._already_exported:
if self.export_onnx is not None:
try:
self.to_onnx(self.export_onnx, x, export_params=True)
except RuntimeError as e:
warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
self._already_exported = True
self.export_onnx = None
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
......@@ -236,7 +250,7 @@ class _ClassificationModule(_SupervisedLearningModule):
class Classification(Lightning):
"""
Trainer that is used for classification.
Evaluator that is used for classification.
Parameters
----------
......@@ -289,7 +303,7 @@ class _RegressionModule(_SupervisedLearningModule):
class Regression(Lightning):
"""
Trainer that is used for regression.
Evaluator that is used for regression.
Parameters
----------
......
......@@ -21,7 +21,9 @@ def set_execution_engine(engine: AbstractExecutionEngine) -> None:
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('Execution engine is already set.')
raise RuntimeError('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
def get_execution_engine() -> AbstractExecutionEngine:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment