Unverified Commit 443ba8c1 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
......@@ -114,7 +114,9 @@ CGO Execution
Utilities
---------
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch
......
......@@ -78,3 +78,9 @@ Utilities
---------
.. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
......@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random
......@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class NasBench101TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
super().__init__()
......
......@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random
......@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class NasBench201TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
super().__init__()
......
......@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from .runtime.log import init_logger
init_logger()
from .common.serializer import *
from .common.serializer import trace, dump, load
from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator
......
......@@ -7,12 +7,12 @@ bohb_advisor.py
import sys
import math
import logging
import json_tricks
from schema import Schema, Optional
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new
import nni
from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
......@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0)
......@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
"""
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def handle_update_search_space(self, data):
......@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = json_tricks.loads(data['hyper_params'])
hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit):
self._request_one_trial_job()
......@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
"""
logger.debug('handle report metric data = %s', data)
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
data['value'] = nni.load(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
......@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
assert 'value' in data
value = extract_scalar_reward(data['value'])
......@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value'
"""
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
entry['value'] = nni.load(entry['value'])
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
......
......@@ -10,10 +10,10 @@ import logging
import math
import sys
import json_tricks
import numpy as np
from schema import Schema, Optional
import nni
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
......@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self):
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def _get_one_trial_job(self):
......@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
......@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params = json_tricks.loads(data['hyper_params'])
hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported
"""
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
data['value'] = nni.load(data['value'])
# multiphase? need to check
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
......@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_')
......
from .serializer import trace, dump, load, is_traceable
This diff is collapsed.
......@@ -6,11 +6,11 @@ from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any
import json_tricks
import colorama
import psutil
import nni.runtime.log
from nni.common import dump
from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult
......@@ -439,7 +439,7 @@ class Experiment:
value: dict
New search_space.
"""
value = json_tricks.dumps(value)
value = dump(value)
self._update_experiment_profile('searchSpace', value)
def update_max_trial_number(self, value: int):
......
......@@ -6,4 +6,4 @@ from .graph import *
from .execution import *
from .fixed import fixed_arch
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
from .serializer import basic_unit, model_wrapper, serialize, serialize_cls
......@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False):
elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
......
......@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from ....serializer import serialize_cls
import nni
class BypassPlugin(TrainingTypePlugin):
......@@ -126,7 +126,7 @@ def get_accelerator_connector(
)
@serialize_cls
@nni.trace
class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
if precision_plugin is None:
......
......@@ -14,10 +14,9 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
from ....serializer import serialize_cls
@serialize_cls
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
......@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@serialize_cls
@nni.trace
class _ClassificationModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
......@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
@nni.trace
class _RegressionModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl
from ....serializer import serialize_cls
import nni
from .accelerator import BypassAccelerator
@serialize_cls
@nni.trace
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
......
......@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torch.utils.data as torch_data
import nni
from nni.common.serializer import is_traceable
try:
from .cgo import trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
from ...graph import Evaluator
from ...serializer import serialize_cls
from nni.retiarii.graph import Evaluator
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
......@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self.model = model
Trainer = serialize_cls(pl.Trainer)
DataLoader = serialize_cls(DataLoader)
Trainer = nni.trace(pl.Trainer)
DataLoader = nni.trace(torch_data.DataLoader)
@nni.trace
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
......@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if cgo_import_failed:
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}'
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
......@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return True
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, DataLoader)
return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
### The following are some commonly used Lightning modules ###
......@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
......@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
......
......@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module._init_parameters.copy()
new_module_init_params = model.evaluator.module.trace_kwargs.copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
......
......@@ -4,13 +4,13 @@
import logging
from typing import Any, Callable
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from .graph import MetricData
from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__)
......@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint': placement_constraint
}
_logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial))
send(CommandType.NewTrialJob, nni.dump(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
......@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
......@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod
def _process_value(value) -> Any: # hopefully a float
value = json_loads(value)
value = nni.load(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from typing import NewType, Any
import nni
from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
......@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params = nni.get_next_parameter()
params = json_loads(json.dumps(params))
return params
......
......@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import NoContextError
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import NoContextError
from .utils import generate_new_label, get_fixed_value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment