Unverified Commit 443ba8c1 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
...@@ -114,7 +114,9 @@ CGO Execution ...@@ -114,7 +114,9 @@ CGO Execution
Utilities Utilities
--------- ---------
.. autofunction:: nni.retiarii.serialize .. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch .. autofunction:: nni.retiarii.fixed_arch
......
...@@ -78,3 +78,9 @@ Utilities ...@@ -78,3 +78,9 @@ Utilities
--------- ---------
.. autofunction:: nni.utils.merge_parameter .. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
...@@ -3,7 +3,7 @@ import nni ...@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn import torch.nn as nn
import torchmetrics import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random from nni.retiarii.strategy import Random
...@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class NasBench101TrainingModule(pl.LightningModule): class NasBench101TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4): def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
super().__init__() super().__init__()
......
...@@ -3,7 +3,7 @@ import nni ...@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn import torch.nn as nn
import torchmetrics import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random from nni.retiarii.strategy import Random
...@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class NasBench201TrainingModule(pl.LightningModule): class NasBench201TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4): def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
super().__init__() super().__init__()
......
...@@ -9,7 +9,7 @@ except ModuleNotFoundError: ...@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from .runtime.log import init_logger from .runtime.log import init_logger
init_logger() init_logger()
from .common.serializer import * from .common.serializer import trace, dump, load
from .runtime.env_vars import dispatcher_env_vars from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator from .utils import ClassArgsValidator
......
...@@ -7,12 +7,12 @@ bohb_advisor.py ...@@ -7,12 +7,12 @@ bohb_advisor.py
import sys import sys
import math import math
import logging import logging
import json_tricks
from schema import Schema, Optional from schema import Schema, Optional
import ConfigSpace as CS import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new from ConfigSpace.read_and_write import pcs_new
import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
...@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase): ...@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0) params = self.generated_hyper_configs.pop(0)
...@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase): ...@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
""" """
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
...@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase): ...@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
logger.debug('Tuner handle trial end, result is %s', data) logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase): ...@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index'] ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit): for _ in range(self.credit):
self._request_one_trial_job() self._request_one_trial_job()
...@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase): ...@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
""" """
logger.debug('handle report metric data = %s', data) logger.debug('handle report metric data = %s', data)
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = nni.load(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
...@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase): ...@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
assert 'value' in data assert 'value' in data
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
...@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase): ...@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value' data doesn't have required key 'parameter' and 'value'
""" """
for entry in data: for entry in data:
entry['value'] = json_tricks.loads(entry['value']) entry['value'] = nni.load(entry['value'])
_completed_num = 0 _completed_num = 0
for trial_info in data: for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data)) logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
......
...@@ -10,10 +10,10 @@ import logging ...@@ -10,10 +10,10 @@ import logging
import math import math
import sys import sys
import json_tricks
import numpy as np import numpy as np
from schema import Schema, Optional from schema import Schema, Optional
import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled from nni.runtime.common import multi_phase_enabled
...@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self): def _request_one_trial_job(self):
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def _get_one_trial_job(self): def _get_one_trial_job(self):
...@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
...@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported Data type not supported
""" """
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = nni.load(data['value'])
# multiphase? need to check # multiphase? need to check
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
...@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None: if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_') bracket_id, i, _ = data['parameter_id'].split('_')
......
from .serializer import trace, dump, load, is_traceable
This diff is collapsed.
...@@ -6,11 +6,11 @@ from subprocess import Popen ...@@ -6,11 +6,11 @@ from subprocess import Popen
import time import time
from typing import Optional, Union, List, overload, Any from typing import Optional, Union, List, overload, Any
import json_tricks
import colorama import colorama
import psutil import psutil
import nni.runtime.log import nni.runtime.log
from nni.common import dump
from .config import ExperimentConfig, AlgorithmConfig from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult from .data import TrialJob, TrialMetricData, TrialResult
...@@ -439,7 +439,7 @@ class Experiment: ...@@ -439,7 +439,7 @@ class Experiment:
value: dict value: dict
New search_space. New search_space.
""" """
value = json_tricks.dumps(value) value = dump(value)
self._update_experiment_profile('searchSpace', value) self._update_experiment_profile('searchSpace', value)
def update_max_trial_number(self, value: int): def update_max_trial_number(self, value: int):
......
...@@ -6,4 +6,4 @@ from .graph import * ...@@ -6,4 +6,4 @@ from .graph import *
from .execution import * from .execution import *
from .fixed import fixed_arch from .fixed import fixed_arch
from .mutator import * from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper from .serializer import basic_unit, model_wrapper, serialize, serialize_cls
...@@ -637,7 +637,7 @@ class GraphConverter: ...@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name not in MODULE_EXCEPT_LIST: original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph # this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False): elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serialize, won't continue to parse # this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None: if m_attrs is not None:
......
...@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin ...@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from ....serializer import serialize_cls import nni
class BypassPlugin(TrainingTypePlugin): class BypassPlugin(TrainingTypePlugin):
...@@ -126,7 +126,7 @@ def get_accelerator_connector( ...@@ -126,7 +126,7 @@ def get_accelerator_connector(
) )
@serialize_cls @nni.trace
class BypassAccelerator(Accelerator): class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs): def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
if precision_plugin is None: if precision_plugin is None:
......
...@@ -14,10 +14,9 @@ import nni ...@@ -14,10 +14,9 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer from .trainer import Trainer
from ....serializer import serialize_cls
@serialize_cls @nni.trace
class _MultiModelSupervisedLearningModule(LightningModule): class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric], def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0, n_models: int = 0,
...@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule): ...@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@serialize_cls @nni.trace
class _ClassificationModule(MultiModelSupervisedLearningModule): class _ClassificationModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
...@@ -174,7 +173,7 @@ class Classification(Lightning): ...@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls @nni.trace
class _RegressionModule(MultiModelSupervisedLearningModule): class _RegressionModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss, def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl import pytorch_lightning as pl
from ....serializer import serialize_cls import nni
from .accelerator import BypassAccelerator from .accelerator import BypassAccelerator
@serialize_cls @nni.trace
class Trainer(pl.Trainer): class Trainer(pl.Trainer):
""" """
Trainer for cross-graph optimization. Trainer for cross-graph optimization.
......
...@@ -10,17 +10,17 @@ import pytorch_lightning as pl ...@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torchmetrics import torchmetrics
from torch.utils.data import DataLoader import torch.utils.data as torch_data
import nni import nni
from nni.common.serializer import is_traceable
try: try:
from .cgo import trainer as cgo_trainer from .cgo import trainer as cgo_trainer
cgo_import_failed = False cgo_import_failed = False
except ImportError: except ImportError:
cgo_import_failed = True cgo_import_failed = True
from ...graph import Evaluator from nni.retiarii.graph import Evaluator
from ...serializer import serialize_cls
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression'] __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
...@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule): ...@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self.model = model self.model = model
Trainer = serialize_cls(pl.Trainer) Trainer = nni.trace(pl.Trainer)
DataLoader = serialize_cls(DataLoader) DataLoader = nni.trace(torch_data.DataLoader)
@nni.trace
class Lightning(Evaluator): class Lightning(Evaluator):
""" """
Delegate the whole training to PyTorch Lightning. Delegate the whole training to PyTorch Lightning.
...@@ -74,9 +75,10 @@ class Lightning(Evaluator): ...@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None): val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.' assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if cgo_import_failed: if cgo_import_failed:
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}' assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else: else:
assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \ # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer' f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.' assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.' assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
...@@ -135,7 +137,7 @@ def _check_dataloader(dataloader): ...@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return True return True
if isinstance(dataloader, list): if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader]) return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, DataLoader) return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
### The following are some commonly used Lightning modules ### ### The following are some commonly used Lightning modules ###
...@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class _ClassificationModule(_SupervisedLearningModule): class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
...@@ -272,7 +274,7 @@ class Classification(Lightning): ...@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls @nni.trace
class _RegressionModule(_SupervisedLearningModule): class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss, def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
......
...@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set # replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module._init_parameters.copy() new_module_init_params = model.evaluator.module.trace_kwargs.copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model) new_module_init_params['n_models'] = len(multi_model)
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
import logging import logging
from typing import Any, Callable from typing import Any, Callable
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
from .integration_api import register_advisor from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint': placement_constraint 'placement_constraint': placement_constraint
} }
_logger.debug('New trial sent: %s', new_trial) _logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial)) send(CommandType.NewTrialJob, nni.dump(new_trial))
if self.send_trial_callback is not None: if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count return self.parameters_count
...@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data): def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data) _logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED') data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
...@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod @staticmethod
def _process_value(value) -> Any: # hopefully a float def _process_value(value) -> Any: # hopefully a float
value = json_loads(value) value = nni.load(value)
if isinstance(value, dict): if isinstance(value, dict):
if 'default' in value: if 'default' in value:
return value['default'] return value['default']
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import json
from typing import NewType, Any from typing import NewType, Any
import nni import nni
from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import # because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any) RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
...@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict: ...@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data. Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
""" """
params = nni.get_next_parameter() params = nni.get_next_parameter()
params = json_loads(json.dumps(params))
return params return params
......
...@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional ...@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...serializer import Translatable, basic_unit from nni.common.serializer import Translatable
from ...utils import NoContextError from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import NoContextError
from .utils import generate_new_label, get_fixed_value from .utils import generate_new_label, get_fixed_value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment