Unverified Commit d5ed88e4 authored by Yuge Zhang, committed by GitHub

Retiarii serializer user experience improvements (#4437)

parent 12b5dbe2
......@@ -14,7 +14,9 @@ The recommendation is, unless you are absolutely certain that there is no proble
**What will happen if I forget to "trace" my objects?**
It is likely that the program can still run. NNI will try to serialize the untraced object into a binary. It might fail in complicated cases (e.g., circular dependency). Even if it succeeds, the result might be a substantially large object. For example, if you forgot to add ``nni.trace`` on ``MNIST``, the MNIST dataset object will be serialized into binary, which will be dozens of megabytes because the object has the whole 60k images stored inside. You might see warnings and even errors when running experiments. To avoid such issues, the easiest way is to always remember to add ``nni.trace`` to non-primitive objects.
It is likely that the program can still run. NNI will try to serialize the untraced object into a binary. It might fail in complex cases, for example when the object is too large. Even if it succeeds, the result might be a substantially large object. For example, if you forgot to add ``nni.trace`` on ``MNIST``, the MNIST dataset object will be serialized into binary, which will be dozens of megabytes because the object has the whole 60k images stored inside. You might see warnings and even errors when running experiments. To avoid such issues, the easiest way is to always remember to add ``nni.trace`` to non-primitive objects.
.. note:: In Retiarii, the serializer will throw an exception when any single object in the recursive serialization is larger than 64 KB once binary-serialized. This indicates that such an object needs to be wrapped by ``nni.trace``. In rare cases, if you insist on pickling large data, the limit can be overridden by setting the environment variable ``PICKLE_SIZE_LIMIT``, whose unit is bytes. Please note that even if the experiment is able to run, this can still cause performance issues and even crash the NNI experiment.
To trace a function or class, users can use the decorator like this:
......
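For reference, a minimal sketch of the pattern these docs describe, assuming a torchvision ``MNIST`` dataset (the ``root`` path is illustrative):

.. code-block:: python

    import nni
    from torchvision.datasets import MNIST

    # nni.trace records the constructor and its arguments for serialization,
    # rather than pickling the instantiated dataset with all 60k images inside
    dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True)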
......@@ -13,12 +13,16 @@ from typing import Any, Dict, List, Optional, TypeVar, Union
import cloudpickle # use cloudpickle as backend for unserializable types and instances
import json_tricks # use json_tricks as serializer backend
__all__ = ['trace', 'dump', 'load', 'Translatable', 'Traceable', 'is_traceable']
__all__ = ['trace', 'dump', 'load', 'PayloadTooLarge', 'Translatable', 'Traceable', 'is_traceable']
T = TypeVar('T')
class PayloadTooLarge(Exception):
pass
class Traceable(abc.ABC):
"""
A traceable object has copy and dict. Copy and mutate are used to copy the object for further mutations.
......@@ -563,9 +567,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
b = cloudpickle.dumps(obj)
if len(b) > pickle_size_limit > 0:
raise ValueError(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.')
raise PayloadTooLarge(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.')
# use base64 to dump a bytes array
return {
'__nni_obj__': base64.b64encode(b).decode()
......
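A sketch of how calling code might handle the new exception; ``blob`` is an illustrative stand-in, and the environment-variable fallback mirrors the note in the docs above:

import os

import nni
from nni.common.serializer import PayloadTooLarge

blob = {'weights': list(range(10))}  # illustrative stand-in for an arbitrary object

try:
    payload = nni.dump(blob, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
except PayloadTooLarge:
    # either wrap the offending class with @nni.trace,
    # or raise the limit via PICKLE_SIZE_LIMIT (unit: bytes)
    raise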
......@@ -26,6 +26,7 @@ class FunctionalEvaluator(Evaluator):
def _dump(self):
return {
'type': self.__class__,
'function': self.function,
'arguments': self.arguments
}
......
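The ``_dump`` above pairs with a ``_load`` counterpart whose body is elided from this diff; a plausible sketch:

@staticmethod
def _load(ir):
    # rebuild the evaluator from the dict produced by _dump
    return FunctionalEvaluator(ir['function'], **ir['arguments'])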
......@@ -93,6 +93,7 @@ class Lightning(Evaluator):
def _dump(self):
return {
'type': self.__class__,
'module': self.module,
'trainer': self.trainer,
'train_dataloader': self.train_dataloader,
......
......@@ -34,13 +34,16 @@ class BaseGraphData:
def dump(self) -> dict:
return {
'model_script': self.model_script,
'evaluator': self.evaluator,
# engine needs to call dump here,
# otherwise, the evaluator will be serialized into binary
# also, the evaluator can be None in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['evaluator'], data['mutation_summary'])
return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
class BaseExecutionEngine(AbstractExecutionEngine):
......
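A hypothetical round trip showing why ``dump`` serializes the evaluator to a dict that ``load`` rebuilds through ``Evaluator._load`` (``ev`` stands for any concrete evaluator instance):

# hypothetical round trip; ev is any concrete Evaluator instance
data = BaseGraphData('model_script_source', ev, {})
restored = BaseGraphData.load(data.dump())
assert type(restored.evaluator) is type(ev)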
from typing import Dict, Any
from typing import Dict, Any, Type
import torch.nn as nn
from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name
from ..utils import ContextStack
from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary
class PythonGraphData:
def __init__(self, class_name: str, init_parameters: Dict[str, Any],
def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_name = class_name
self.class_ = class_
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
......@@ -18,16 +20,19 @@ class PythonGraphData:
def dump(self) -> dict:
return {
'class_name': self.class_name,
'class': self.class_,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
'evaluator': self.evaluator,
# engine needs to call dump here,
# otherwise, the evaluator will be serialized into binary
# also, the evaluator can be None in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator'])
return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))
class PurePythonExecutionEngine(BaseExecutionEngine):
......@@ -44,17 +49,15 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
model.python_init_params, mutation, model.evaluator)
graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())
class _model(import_(graph_data.class_name)):
def __init__(self):
super().__init__(**graph_data.init_parameters)
def _model():
return graph_data.class_(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)
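To make the new flow concrete: the model class object itself now travels through the trial parameters (the trace machinery can serialize it), so ``_model`` becomes a plain factory. A sketch with illustrative names (``MyModel``, ``ev``):

# illustrative sketch: MyModel and ev are placeholders
graph_data = PythonGraphData(MyModel, {'hidden_dim': 64}, {}, ev)
restored = PythonGraphData.load(graph_data.dump())
model = restored.class_(**restored.init_parameters)  # equivalent to _model() above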
......@@ -6,6 +6,7 @@ import logging
import os
import socket
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
......@@ -35,6 +36,7 @@ from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation
from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space
......@@ -185,6 +187,13 @@ class RetiariiExperiment(Experiment):
self.url_prefix = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may behave inconsistently compared to when the decorator is added.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators,
......
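The decorator the warning above refers to is applied like this; a minimal sketch with an illustrative module body:

import torch.nn as nn
from nni.retiarii import model_wrapper

@model_wrapper
class Net(nn.Module):
    def __init__(self, hidden_dim: int = 64):  # illustrative init parameter
        super().__init__()
        self.fc = nn.Linear(32, hidden_dim)

    def forward(self, x):
        return self.fc(x)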
......@@ -11,7 +11,7 @@ from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_importable_name, import_, uid
from .utils import uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
......@@ -41,20 +41,25 @@ class Evaluator(abc.ABC):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@abc.abstractstaticmethod
def _load(ir: Any) -> 'Evaluator':
pass
@staticmethod
def _load_with_type(type_name: str, ir: Any) -> 'Optional[Evaluator]':
if type_name == '_debug_no_trainer':
return DebugEvaluator()
config_cls = import_(type_name)
assert issubclass(config_cls, Evaluator)
return config_cls._load(ir)
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(evaluator_type, Evaluator)
return evaluator_type._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclasses implement ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
......@@ -154,16 +159,13 @@ class Model:
if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register()
if '_evaluator' in ir:
model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
model.evaluator = Evaluator._load(ir['_evaluator'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
if self.evaluator is not None:
ret['_evaluator'] = {
'__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump()
}
ret['_evaluator'] = self.evaluator._dump()
return ret
def get_nodes(self) -> Iterable['Node']:
......@@ -787,7 +789,7 @@ class DebugEvaluator(Evaluator):
return DebugEvaluator()
def _dump(self) -> Any:
return {'__type__': '_debug_no_trainer'}
return {'type': DebugEvaluator}
def _execute(self, model_cls: type) -> Any:
pass
......
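A minimal sketch of the ``_dump``/``_load`` convention the new docstring describes; the subclass and its field are illustrative:

class ConstantEvaluator(Evaluator):
    """Illustrative subclass that returns a fixed metric."""

    def __init__(self, metric: float):
        self.metric = metric

    @staticmethod
    def _load(ir):
        return ConstantEvaluator(ir['metric'])

    def _dump(self):
        # 'type' is the class itself, per the convention above
        return {'type': ConstantEvaluator, 'metric': self.metric}

    def _execute(self, model_cls):
        return self.metric

    def __eq__(self, other):
        return isinstance(other, ConstantEvaluator) and other.metric == self.metric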
......@@ -2,10 +2,11 @@
# Licensed under the MIT license.
import logging
import warnings
import os
from typing import Any, Callable
import nni
from nni.common.serializer import PayloadTooLarge
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
......@@ -123,14 +124,16 @@ class RetiariiAdvisor(MsgDispatcherBase):
}
_logger.debug('New trial sent: %s', new_trial)
send_payload = nni.dump(new_trial, pickle_size_limit=-1)
if len(send_payload) > 256 * 1024:
warnings.warn(
'The total payload of the trial is larger than 256 KB. '
'This can cause performance issues and even the crash of NNI experiment. '
try:
send_payload = nni.dump(new_trial, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
except PayloadTooLarge:
raise ValueError(
'Serialization failed when trying to dump the model because the payload is too large (larger than 64 KB). '
'This is usually caused by pickling large objects (like datasets) by mistake. '
'See https://nni.readthedocs.io/en/stable/NAS/Serialization.html for details.'
'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
'for how to resolve such issues. '
)
# trial parameters can be super large, disable pickle size limit here
# nevertheless, it could still be blocked by the pipe / nni-manager
send(CommandType.NewTrialJob, send_payload)
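If a trial legitimately needs a larger payload, the limit can be raised before launching the experiment; a sketch (the 128 KB value is illustrative):

import os

# raise the per-object pickle limit to 128 KB (unit: bytes);
# oversized payloads may still be blocked by the pipe / nni-manager
os.environ['PICKLE_SIZE_LIMIT'] = str(128 * 1024)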
......@@ -175,4 +178,4 @@ class RetiariiAdvisor(MsgDispatcherBase):
return value['default']
else:
return value
return value
\ No newline at end of file
return value
......@@ -8,7 +8,7 @@ import torch.nn as nn
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
......@@ -223,7 +223,7 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
graph = Graph(model, uid(), '_model', _internal=True)._register()
model.python_class = pytorch_model.__class__
if len(inspect.signature(model.python_class.__init__).parameters) > 1:
if not getattr(pytorch_model, '_nni_model_wrapper', False):
if not is_model_wrapped(pytorch_model):
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.')
model.python_init_params = pytorch_model.trace_kwargs
......
......@@ -39,6 +39,6 @@
},
"_evaluator": {
"__type__": "_debug_no_trainer"
"type": "DebugEvaluator"
}
}
......@@ -33,6 +33,11 @@ def _test_file(json_path):
# debug output
#json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
#json.dump(dump_ir, open('_dump.json', 'w'), indent=4)
# skip comparison of _evaluator
orig_ir.pop('_evaluator')
dump_ir.pop('_evaluator')
assert orig_ir == dump_ir
......