Unverified Commit d5ed88e4 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Retiarii serializer user experience improvements (#4437)

parent 12b5dbe2
...@@ -14,7 +14,9 @@ The recommendation is, unless you are absolutely certain that there is no proble ...@@ -14,7 +14,9 @@ The recommendation is, unless you are absolutely certain that there is no proble
**What will happen if I forget to "trace" my objects?** **What will happen if I forget to "trace" my objects?**
It is likely that the program can still run. NNI will try to serialize the untraced object into a binary. If might fail in complicated cases (e.g., circular dependency). Even if it succeeds, the result might be a substantially large object. For example, if you forgot to add ``nni.trace`` on ``MNIST``, the MNIST dataset object wil be serialized into binary, which will be dozens of megabytes because the object has the whole 60k images stored inside. You might see warnings and even errors when running experiments. To avoid such issues, the easiest way is to always remember to add ``nni.trace`` to non-primitive objects. It is likely that the program can still run. NNI will try to serialize the untraced object into a binary. It might fail in complex cases. For example, when the object is too large. Even if it succeeds, the result might be a substantially large object. For example, if you forgot to add ``nni.trace`` on ``MNIST``, the MNIST dataset object wil be serialized into binary, which will be dozens of megabytes because the object has the whole 60k images stored inside. You might see warnings and even errors when running experiments. To avoid such issues, the easiest way is to always remember to add ``nni.trace`` to non-primitive objects.
.. note:: In Retiarii, serializer will throw exception when one of an single object in the recursive serialization is larger than 64 KB when binary serialized. This indicates that such object needs to be wrapped by ``nni.trace``. In rare cases, if you insist on pickling large data, the limit can be overridden by setting an environment variable ``PICKLE_SIZE_LIMIT``, whose unit is byte. Please note that even if the experiment might be able to run, this can still cause performance issues and even the crash of NNI experiment.
To trace a function or class, users can use decorator like, To trace a function or class, users can use decorator like,
......
...@@ -13,12 +13,16 @@ from typing import Any, Dict, List, Optional, TypeVar, Union ...@@ -13,12 +13,16 @@ from typing import Any, Dict, List, Optional, TypeVar, Union
import cloudpickle # use cloudpickle as backend for unserializable types and instances import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend import json_tricks # use json_tricks as serializer backend
__all__ = ['trace', 'dump', 'load', 'Translatable', 'Traceable', 'is_traceable'] __all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable']
T = TypeVar('T') T = TypeVar('T')
class PayloadTooLarge(Exception):
pass
class Traceable(abc.ABC): class Traceable(abc.ABC):
""" """
A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations. A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations.
...@@ -563,9 +567,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si ...@@ -563,9 +567,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')): if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
b = cloudpickle.dumps(obj) b = cloudpickle.dumps(obj)
if len(b) > pickle_size_limit > 0: if len(b) > pickle_size_limit > 0:
raise ValueError(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are ' raise PayloadTooLarge(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
'not decorated by @nni.trace. Another option is to force bytes pickling and ' 'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.') 'try to raise pickle_size_limit.')
# use base64 to dump a bytes array # use base64 to dump a bytes array
return { return {
'__nni_obj__': base64.b64encode(b).decode() '__nni_obj__': base64.b64encode(b).decode()
......
...@@ -26,6 +26,7 @@ class FunctionalEvaluator(Evaluator): ...@@ -26,6 +26,7 @@ class FunctionalEvaluator(Evaluator):
def _dump(self): def _dump(self):
return { return {
'type': self.__class__,
'function': self.function, 'function': self.function,
'arguments': self.arguments 'arguments': self.arguments
} }
......
...@@ -93,6 +93,7 @@ class Lightning(Evaluator): ...@@ -93,6 +93,7 @@ class Lightning(Evaluator):
def _dump(self): def _dump(self):
return { return {
'type': self.__class__,
'module': self.module, 'module': self.module,
'trainer': self.trainer, 'trainer': self.trainer,
'train_dataloader': self.train_dataloader, 'train_dataloader': self.train_dataloader,
......
...@@ -34,13 +34,16 @@ class BaseGraphData: ...@@ -34,13 +34,16 @@ class BaseGraphData:
def dump(self) -> dict: def dump(self) -> dict:
return { return {
'model_script': self.model_script, 'model_script': self.model_script,
'evaluator': self.evaluator, # engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary 'mutation_summary': self.mutation_summary
} }
@staticmethod @staticmethod
def load(data) -> 'BaseGraphData': def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['evaluator'], data['mutation_summary']) return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
class BaseExecutionEngine(AbstractExecutionEngine): class BaseExecutionEngine(AbstractExecutionEngine):
......
from typing import Dict, Any from typing import Dict, Any, Type
import torch.nn as nn
from ..graph import Evaluator, Model from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name from ..utils import ContextStack
from .base import BaseExecutionEngine from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary from .utils import get_mutation_dict, mutation_dict_to_summary
class PythonGraphData: class PythonGraphData:
def __init__(self, class_name: str, init_parameters: Dict[str, Any], def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None: mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_name = class_name self.class_ = class_
self.init_parameters = init_parameters self.init_parameters = init_parameters
self.mutation = mutation self.mutation = mutation
self.evaluator = evaluator self.evaluator = evaluator
...@@ -18,16 +20,19 @@ class PythonGraphData: ...@@ -18,16 +20,19 @@ class PythonGraphData:
def dump(self) -> dict: def dump(self) -> dict:
return { return {
'class_name': self.class_name, 'class': self.class_,
'init_parameters': self.init_parameters, 'init_parameters': self.init_parameters,
'mutation': self.mutation, 'mutation': self.mutation,
'evaluator': self.evaluator, # engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary 'mutation_summary': self.mutation_summary
} }
@staticmethod @staticmethod
def load(data) -> 'PythonGraphData': def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator']) return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))
class PurePythonExecutionEngine(BaseExecutionEngine): class PurePythonExecutionEngine(BaseExecutionEngine):
...@@ -44,17 +49,15 @@ class PurePythonExecutionEngine(BaseExecutionEngine): ...@@ -44,17 +49,15 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod @classmethod
def pack_model_data(cls, model: Model) -> Any: def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model) mutation = get_mutation_dict(model)
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True), graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
model.python_init_params, mutation, model.evaluator)
return graph_data return graph_data
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters()) graph_data = PythonGraphData.load(receive_trial_parameters())
class _model(import_(graph_data.class_name)): def _model():
def __init__(self): return graph_data.class_(**graph_data.init_parameters)
super().__init__(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation): with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model) graph_data.evaluator._execute(_model)
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import os import os
import socket import socket
import time import time
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from subprocess import Popen from subprocess import Popen
...@@ -35,6 +36,7 @@ from ..integration import RetiariiAdvisor ...@@ -35,6 +36,7 @@ from ..integration import RetiariiAdvisor
from ..mutator import Mutator from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation
from ..oneshot.interface import BaseOneShotTrainer from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space from ..strategy.utils import dry_run_for_formatted_search_space
...@@ -185,6 +187,13 @@ class RetiariiExperiment(Experiment): ...@@ -185,6 +187,13 @@ class RetiariiExperiment(Experiment):
self.url_prefix = None self.url_prefix = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _start_strategy(self): def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model( base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, self.base_model, self.trainer, self.applied_mutators,
......
...@@ -11,7 +11,7 @@ from enum import Enum ...@@ -11,7 +11,7 @@ from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload) from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_importable_name, import_, uid from .utils import uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData'] __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
...@@ -41,20 +41,25 @@ class Evaluator(abc.ABC): ...@@ -41,20 +41,25 @@ class Evaluator(abc.ABC):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()]) items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})' return f'{self.__class__.__name__}({items})'
@abc.abstractstaticmethod
def _load(ir: Any) -> 'Evaluator':
pass
@staticmethod @staticmethod
def _load_with_type(type_name: str, ir: Any) -> 'Optional[Evaluator]': def _load(ir: Any) -> 'Evaluator':
if type_name == '_debug_no_trainer': evaluator_type = ir.get('type')
return DebugEvaluator() if isinstance(evaluator_type, str):
config_cls = import_(type_name) # for debug purposes only
assert issubclass(config_cls, Evaluator) for subclass in Evaluator.__subclasses__():
return config_cls._load(ir) if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(evaluator_type, Evaluator)
return evaluator_type._load(ir)
@abc.abstractmethod @abc.abstractmethod
def _dump(self) -> Any: def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
...@@ -154,16 +159,13 @@ class Model: ...@@ -154,16 +159,13 @@ class Model:
if graph_name != '_evaluator': if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register() Graph._load(model, graph_name, graph_data)._register()
if '_evaluator' in ir: if '_evaluator' in ir:
model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator']) model.evaluator = Evaluator._load(ir['_evaluator'])
return model return model
def _dump(self) -> Any: def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()} ret = {name: graph._dump() for name, graph in self.graphs.items()}
if self.evaluator is not None: if self.evaluator is not None:
ret['_evaluator'] = { ret['_evaluator'] = self.evaluator._dump()
'__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump()
}
return ret return ret
def get_nodes(self) -> Iterable['Node']: def get_nodes(self) -> Iterable['Node']:
...@@ -787,7 +789,7 @@ class DebugEvaluator(Evaluator): ...@@ -787,7 +789,7 @@ class DebugEvaluator(Evaluator):
return DebugEvaluator() return DebugEvaluator()
def _dump(self) -> Any: def _dump(self) -> Any:
return {'__type__': '_debug_no_trainer'} return {'type': DebugEvaluator}
def _execute(self, model_cls: type) -> Any: def _execute(self, model_cls: type) -> Any:
pass pass
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import warnings import os
from typing import Any, Callable from typing import Any, Callable
import nni import nni
from nni.common.serializer import PayloadTooLarge
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
...@@ -123,14 +124,16 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -123,14 +124,16 @@ class RetiariiAdvisor(MsgDispatcherBase):
} }
_logger.debug('New trial sent: %s', new_trial) _logger.debug('New trial sent: %s', new_trial)
send_payload = nni.dump(new_trial, pickle_size_limit=-1) try:
if len(send_payload) > 256 * 1024: send_payload = nni.dump(new_trial, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
warnings.warn( except PayloadTooLarge:
'The total payload of the trial is larger than 50 KB. ' raise ValueError(
'This can cause performance issues and even the crash of NNI experiment. ' 'Serialization failed when trying to dump the model because payload too large (larger than 64 KB). '
'This is usually caused by pickling large objects (like datasets) by mistake. ' 'This is usually caused by pickling large objects (like datasets) by mistake. '
'See https://nni.readthedocs.io/en/stable/NAS/Serialization.html for details.' 'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
'for how to resolve such issue. '
) )
# trial parameters can be super large, disable pickle size limit here # trial parameters can be super large, disable pickle size limit here
# nevertheless, there could still be blocked by pipe / nni-manager # nevertheless, there could still be blocked by pipe / nni-manager
send(CommandType.NewTrialJob, send_payload) send(CommandType.NewTrialJob, send_payload)
...@@ -175,4 +178,4 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -175,4 +178,4 @@ class RetiariiAdvisor(MsgDispatcherBase):
return value['default'] return value['default']
else: else:
return value return value
return value return value
\ No newline at end of file
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node
from nni.retiarii.mutator import Mutator from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
...@@ -223,7 +223,7 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op ...@@ -223,7 +223,7 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
graph = Graph(model, uid(), '_model', _internal=True)._register() graph = Graph(model, uid(), '_model', _internal=True)._register()
model.python_class = pytorch_model.__class__ model.python_class = pytorch_model.__class__
if len(inspect.signature(model.python_class.__init__).parameters) > 1: if len(inspect.signature(model.python_class.__init__).parameters) > 1:
if not getattr(pytorch_model, '_nni_model_wrapper', False): if not is_model_wrapped(pytorch_model):
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode ' raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.') 'if your model has init parameters.')
model.python_init_params = pytorch_model.trace_kwargs model.python_init_params = pytorch_model.trace_kwargs
......
...@@ -39,6 +39,6 @@ ...@@ -39,6 +39,6 @@
}, },
"_evaluator": { "_evaluator": {
"__type__": "_debug_no_trainer" "type": "DebugEvaluator"
} }
} }
...@@ -33,6 +33,11 @@ def _test_file(json_path): ...@@ -33,6 +33,11 @@ def _test_file(json_path):
# debug output # debug output
#json.dump(orig_ir, open('_orig.json', 'w'), indent=4) #json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
#json.dump(dump_ir, open('_dump.json', 'w'), indent=4) #json.dump(dump_ir, open('_dump.json', 'w'), indent=4)
# skip comparison of _evaluator
orig_ir.pop('_evaluator')
dump_ir.pop('_evaluator')
assert orig_ir == dump_ir assert orig_ir == dump_ir
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment