"src/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "ec84188ffc24a27ce1ce36148c099a49823a0b8e"
Unverified Commit 85fd39a7 authored by Yuge Zhang, committed by GitHub
Browse files

[Retiarii] Serializer and experiment status fixes (#3421)

parent d69d4ae9
...@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node ...@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell, Operation from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail from ..serializer import get_init_parameters_or_fail
from ..utils import get_full_class_name from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name from .utils import _convert_name, build_full_name
...@@ -536,7 +536,7 @@ class GraphConverter: ...@@ -536,7 +536,7 @@ class GraphConverter:
def _handle_layerchoice(self, module): def _handle_layerchoice(self, module):
choices = [] choices = []
for cand in list(module): for cand in list(module):
cand_type = '__torch__.' + get_full_class_name(cand.__class__) cand_type = '__torch__.' + get_importable_name(cand.__class__)
choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)}) choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)})
return { return {
'candidates': choices, 'candidates': choices,
......
import logging import logging
import time
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from subprocess import Popen from subprocess import Popen
...@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment): ...@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
self._proc: Optional[Popen] = None self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None self._pipe: Optional[Pipe] = None
self._strategy_thread: Optional[Thread] = None
def _start_strategy(self): def _start_strategy(self):
try: try:
script_module = torch.jit.script(self.base_model) script_module = torch.jit.script(self.base_model)
...@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment): ...@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
self.applied_mutators = mutators self.applied_mutators = mutators
_logger.info('Starting strategy...') _logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start() # This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self._strategy_thread = Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators))
self._strategy_thread.start()
_logger.info('Strategy started!') _logger.info('Strategy started!')
Thread(target=self._strategy_monitor).start()
def start(self, port: int = 8080, debug: bool = False) -> None: def start(self, port: int = 8080, debug: bool = False) -> None:
""" """
...@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment): ...@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
def _create_dispatcher(self): def _create_dispatcher(self):
return self._dispatcher return self._dispatcher
def _strategy_monitor(self):
self._strategy_thread.join()
self._dispatcher.mark_experiment_as_ending()
def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str: def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
""" """
Run the experiment. Run the experiment.
......
...@@ -9,7 +9,7 @@ from enum import Enum ...@@ -9,7 +9,7 @@ from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload) from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_full_class_name, import_, uid from .utils import get_importable_name, import_, uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData'] __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
...@@ -147,7 +147,7 @@ class Model: ...@@ -147,7 +147,7 @@ class Model:
def _dump(self) -> Any: def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()} ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_evaluator'] = { ret['_evaluator'] = {
'__type__': get_full_class_name(self.evaluator.__class__), '__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump() **self.evaluator._dump()
} }
return ret return ret
......
...@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.send_trial_callback(parameters) # pylint: disable=not-callable self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count return self.parameters_count
def mark_experiment_as_ending(self):
send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials): def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: %s', num_trials) _logger.info('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None: if self.request_trial_jobs_callback is not None:
......
import abc import abc
import functools import functools
import inspect import inspect
import types
from typing import Any from typing import Any
import json_tricks import json_tricks
from .utils import get_full_class_name, get_module_name, import_ from .utils import get_importable_name, get_module_name, import_
def get_init_parameters_or_fail(obj, silently=False): def get_init_parameters_or_fail(obj, silently=False):
...@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False): ...@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
try: # FIXME: raise error try: # FIXME: raise error
if hasattr(obj, '__class__'): if hasattr(obj, '__class__'):
return { return {
'__type__': get_full_class_name(obj.__class__), '__type__': get_importable_name(obj.__class__),
'arguments': get_init_parameters_or_fail(obj) 'arguments': get_init_parameters_or_fail(obj)
} }
except ValueError: except ValueError:
...@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj): ...@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
def _type_encode(obj, primitives=False): def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.' assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type): if isinstance(obj, type):
return {'__typename__': get_full_class_name(obj, relocate_module=True)} return {'__typename__': get_importable_name(obj, relocate_module=True)}
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
# This is not reliable for cases like closures, `open`, or objects that are callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return {'__typename__': get_importable_name(obj, relocate_module=True)}
return obj return obj
......
import inspect import inspect
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
from pathlib import Path from pathlib import Path
...@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int: ...@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace] return _last_uid[namespace]
def get_module_name(cls): def get_module_name(cls_or_func):
module_name = cls.__module__ module_name = cls_or_func.__module__
if module_name == '__main__': if module_name == '__main__':
# infer the module name with inspect # infer the module name with inspect
for frm in inspect.stack(): for frm in inspect.stack():
...@@ -40,17 +41,20 @@ def get_module_name(cls): ...@@ -40,17 +41,20 @@ def get_module_name(cls):
f'please launch the experiment under the directory where "{main_file_path.name}" is located.') f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem module_name = main_file_path.stem
break break
if module_name == '__main__':
warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
'function/class cannot be imported.')
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something. # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__ # the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM # TODO: find out all the modules that have the same requirement as LSTM
if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM': if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls.__module__ module_name = cls_or_func.__module__
return module_name return module_name
def get_full_class_name(cls, relocate_module=False): def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__ module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__ return module_name + '.' + cls.__name__
import json import json
import math
from pathlib import Path from pathlib import Path
import re import re
import sys import sys
...@@ -84,6 +85,8 @@ def test_type(): ...@@ -84,6 +85,8 @@ def test_type():
assert json_dumps(torch.optim.Adam) == '{"__typename__": "torch.optim.adam.Adam"}' assert json_dumps(torch.optim.Adam) == '{"__typename__": "torch.optim.adam.Adam"}'
assert json_loads('{"__typename__": "torch.optim.adam.Adam"}') == torch.optim.Adam assert json_loads('{"__typename__": "torch.optim.adam.Adam"}') == torch.optim.Adam
assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', json_dumps(Foo)) assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', json_dumps(Foo))
assert json_dumps(math.floor) == '{"__typename__": "math.floor"}'
assert json_loads('{"__typename__": "math.floor"}') == math.floor
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment