Unverified Commit 85fd39a7 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Serializer and experiment status fixes (#3421)

parent d69d4ae9
......@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_full_class_name
from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name
......@@ -536,7 +536,7 @@ class GraphConverter:
def _handle_layerchoice(self, module):
choices = []
for cand in list(module):
cand_type = '__torch__.' + get_full_class_name(cand.__class__)
cand_type = '__torch__.' + get_importable_name(cand.__class__)
choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)})
return {
'candidates': choices,
......
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
......@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self._strategy_thread: Optional[Thread] = None
def _start_strategy(self):
try:
script_module = torch.jit.script(self.base_model)
......@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
self.applied_mutators = mutators
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
# This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self._strategy_thread = Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators))
self._strategy_thread.start()
_logger.info('Strategy started!')
Thread(target=self._strategy_monitor).start()
def start(self, port: int = 8080, debug: bool = False) -> None:
"""
......@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
def _create_dispatcher(self):
return self._dispatcher
def _strategy_monitor(self):
self._strategy_thread.join()
self._dispatcher.mark_experiment_as_ending()
def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
......
......@@ -9,7 +9,7 @@ from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_full_class_name, import_, uid
from .utils import get_importable_name, import_, uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
......@@ -147,7 +147,7 @@ class Model:
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_evaluator'] = {
'__type__': get_full_class_name(self.evaluator.__class__),
'__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump()
}
return ret
......
......@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
def mark_experiment_as_ending(self):
send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None:
......
import abc
import functools
import inspect
import types
from typing import Any
import json_tricks
from .utils import get_full_class_name, get_module_name, import_
from .utils import get_importable_name, get_module_name, import_
def get_init_parameters_or_fail(obj, silently=False):
......@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
try: # FIXME: raise error
if hasattr(obj, '__class__'):
return {
'__type__': get_full_class_name(obj.__class__),
'__type__': get_importable_name(obj.__class__),
'arguments': get_init_parameters_or_fail(obj)
}
except ValueError:
......@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_full_class_name(obj, relocate_module=True)}
return {'__typename__': get_importable_name(obj, relocate_module=True)}
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return {'__typename__': get_importable_name(obj, relocate_module=True)}
return obj
......
import inspect
import warnings
from collections import defaultdict
from typing import Any
from pathlib import Path
......@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace]
def get_module_name(cls):
module_name = cls.__module__
def get_module_name(cls_or_func):
module_name = cls_or_func.__module__
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
......@@ -40,17 +41,20 @@ def get_module_name(cls):
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
if module_name == '__main__':
warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
'function/class cannot be imported.')
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls.__module__
if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls_or_func.__module__
return module_name
def get_full_class_name(cls, relocate_module=False):
def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
import json
import math
from pathlib import Path
import re
import sys
......@@ -84,6 +85,8 @@ def test_type():
assert json_dumps(torch.optim.Adam) == '{"__typename__": "torch.optim.adam.Adam"}'
assert json_loads('{"__typename__": "torch.optim.adam.Adam"}') == torch.optim.Adam
assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', json_dumps(Foo))
assert json_dumps(math.floor) == '{"__typename__": "math.floor"}'
assert json_loads('{"__typename__": "math.floor"}') == math.floor
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment