Commit 8098d160, authored by Yanghan Wang, committed by Facebook GitHub Bot
Browse files

support cache for registry bootstrap

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/286

Support cache for registry bootstrap.

Reviewed By: tglik

Differential Revision: D36798475

fbshipit-source-id: 5f32093d6d8d1d1db896d6b2567ddd0187d66248
parent 6d9a016a
...@@ -5,13 +5,20 @@ import ast ...@@ -5,13 +5,20 @@ import ast
import builtins import builtins
import contextlib import contextlib
import glob import glob
import hashlib
import logging import logging
import os import os
import tempfile
import time import time
import traceback import traceback
from collections import defaultdict
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
import pkg_resources import pkg_resources
from mobile_cv.common.misc.py import MoreMagicMock import yaml
from mobile_cv.common.misc.py import dynamic_import, MoreMagicMock
from mobile_cv.common.misc.registry import ( from mobile_cv.common.misc.registry import (
CLASS_OR_FUNCTION_TYPES, CLASS_OR_FUNCTION_TYPES,
LazyRegisterable, LazyRegisterable,
...@@ -28,15 +35,29 @@ _INSIDE_BOOTSTRAP = False ...@@ -28,15 +35,29 @@ _INSIDE_BOOTSTRAP = False
_IS_BOOTSTRAPPED = False _IS_BOOTSTRAPPED = False
_BOOTSTRAP_PACKAGE = "d2go.registry._bootstrap" _BOOTSTRAP_PACKAGE = "d2go.registry._bootstrap"
_BOOTSTRAP_CACHE_FILENAME = "registry_bootstrap.v1.yaml"
def _log(lvl, msg): def _log(lvl: int, msg: str):
_VERBOSE_LEVEL = 0 _VERBOSE_LEVEL = 0
if _VERBOSE_LEVEL >= lvl: if _VERBOSE_LEVEL >= lvl:
print(msg) print(msg)
# Simple version copied from fvcore/iopath
def _get_cache_dir() -> str:
    """Return a directory that can hold D2Go's on-disk cache files.

    ``~/.torch/d2go_cache`` is preferred and is created on demand. If it cannot
    be created, or is not readable, writable and searchable, the function falls
    back to ``d2go_cache`` under the system temporary directory, logs a warning,
    and makes a best-effort attempt to create that fallback as well.

    Returns:
        The path of the selected cache directory.
    """
    cache_dir = os.path.expanduser("~/.torch/d2go_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would silently skip the permission test.
        accessible = os.access(cache_dir, os.R_OK | os.W_OK | os.X_OK)
    except OSError:
        accessible = False
    if accessible:
        return cache_dir

    tmp_dir = os.path.join(tempfile.gettempdir(), "d2go_cache")
    logger.warning(f"{cache_dir} is not accessible! Using {tmp_dir} instead!")
    # The fallback must exist too, otherwise writing a cache file into it
    # fails with FileNotFoundError. Stay best-effort: never raise from here.
    try:
        os.makedirs(tmp_dir, exist_ok=True)
    except OSError as e:
        logger.warning(f"Failed to create {tmp_dir}: {e}")
    return tmp_dir
class _catchtime: class _catchtime:
def __enter__(self): def __enter__(self):
self.time = time.perf_counter() self.time = time.perf_counter()
...@@ -113,7 +134,7 @@ def _open_mock(*args, **kwargs): ...@@ -113,7 +134,7 @@ def _open_mock(*args, **kwargs):
return MoreMagicMock() return MoreMagicMock()
def _register_mock(self, name, obj): def _register_mock(self, name: Optional[str], obj: Any):
"""Convert `obj` to LazyRegisterable""" """Convert `obj` to LazyRegisterable"""
# Instead of register the (possibly mocked) object which is created under the # Instead of register the (possibly mocked) object which is created under the
...@@ -163,13 +184,52 @@ def _bootstrap_patch(): ...@@ -163,13 +184,52 @@ def _bootstrap_patch():
_INSIDE_BOOTSTRAP = False _INSIDE_BOOTSTRAP = False
def _bootstrap_file(filename): def _get_registered_names() -> Dict[str, List[str]]:
# convert absolute path to full module name """Return the currently registered names for each registry"""
# NOTE: currently only support D2Go's builtin registry module, which can be extended
# in future.
import d2go.registry.builtin
modules = [
d2go.registry.builtin,
]
registered = {}
for module in modules:
registered_in_module = {
f"{module.__name__}.{name}": obj.get_names()
for name, obj in module.__dict__.items()
if isinstance(obj, Registry)
}
registered.update(registered_in_module)
return registered
class BootstrapStatus(Enum):
    """Outcome of bootstrapping a single file, as reported by `_bootstrap_file`."""

    # The file's hash matched the cached result, so the cached names were
    # lazily registered without executing the file.
    CACHED = 0
    # The file was executed under the bootstrap patch and ran to completion.
    FULLY_IMPORTED = 1
    # Execution stopped early because `_BootstrapBreakException` was raised.
    PARTIALLY_IMPORTED = 2
    # Execution raised some other exception, which was caught and logged
    # (only happens when `catch_exception` is true; otherwise it is re-raised).
    FAILED = 3
@dataclass
class CachedResult:
    """Bootstrap outcome for one file, in a form that can be dumped to the cache.

    Instances are created from keyword arguments loaded from the cache file and
    are converted back to plain dicts with `dataclasses.asdict` when dumping.
    """

    # Hex SHA-1 digest of the file content this result was computed from.
    sha1: str
    # Dotted registry name -> names that bootstrapping this file registered in
    # that registry (a sorted list of names, not a single string).
    registered: Dict[str, List[str]]
    status: str  # string representation of BootstrapStatus
def _bootstrap_file(
rel_path: str,
catch_exception: bool,
cached_result: Optional[CachedResult] = None,
) -> Tuple[CachedResult, BootstrapStatus]:
# convert relative path to full module name
# eg. ".../d2go/a/b/c.py" -> "d2go.a.b.c" # eg. ".../d2go/a/b/c.py" -> "d2go.a.b.c"
# eg. ".../d2go/a/b/__init__.py" -> "d2go.a.b" # eg. ".../d2go/a/b/__init__.py" -> "d2go.a.b"
package_root = os.path.dirname(pkg_resources.resource_filename("d2go", "")) package_root = os.path.dirname(pkg_resources.resource_filename("d2go", ""))
assert filename.startswith(package_root), (filename, package_root) filename = os.path.join(package_root, rel_path)
rel_path = os.path.relpath(filename, package_root)
assert rel_path.endswith(".py") assert rel_path.endswith(".py")
module = rel_path[: -len(".py")] module = rel_path[: -len(".py")]
if module.endswith("/__init__"): if module.endswith("/__init__"):
...@@ -185,6 +245,22 @@ def _bootstrap_file(filename): ...@@ -185,6 +245,22 @@ def _bootstrap_file(filename):
with _catchtime() as t: with _catchtime() as t:
with open(filename) as f: with open(filename) as f:
content = f.read() content = f.read()
file_hash = hashlib.sha1(content.encode("utf-8")).hexdigest()
if cached_result is not None and file_hash == cached_result.sha1:
_log(
2,
f"Hash {file_hash} matches, lazy registering cached registerables ...",
)
registerables = cached_result.registered
for registry_module_dot_name, names_to_register in registerables.items():
registry = dynamic_import(registry_module_dot_name)
for name in names_to_register:
# we only store the registered name in the cache, here we know the
# module of bootstrapped file, which should be sufficient.
registry.register(name, LazyRegisterable(module=module))
return cached_result, BootstrapStatus.CACHED
tree = ast.parse(content) tree = ast.parse(content)
# HACK: convert multiple inheritance to single inheritance, this is needed # HACK: convert multiple inheritance to single inheritance, this is needed
...@@ -198,11 +274,41 @@ def _bootstrap_file(filename): ...@@ -198,11 +274,41 @@ def _bootstrap_file(filename):
_log(2, f"Parsing AST takes {t.time} sec") _log(2, f"Parsing AST takes {t.time} sec")
prev_registered = _get_registered_names()
with _catchtime() as t: with _catchtime() as t:
with _bootstrap_patch(): try:
exec(compile(tree, filename, "exec"), exec_globals) # noqa with _bootstrap_patch():
exec(compile(tree, filename, "exec"), exec_globals) # noqa
status = BootstrapStatus.FULLY_IMPORTED
except _BootstrapBreakException:
status = BootstrapStatus.PARTIALLY_IMPORTED
except Exception as e:
if catch_exception:
_log(
1,
"Encountered the following error during bootstrap:"
+ "".join(traceback.format_exception(type(e), e, e.__traceback__)),
)
else:
raise e
status = BootstrapStatus.FAILED
_log(2, f"Execute file takes {t.time} sec") _log(2, f"Execute file takes {t.time} sec")
# compare and get the newly registered
cur_registered = _get_registered_names()
assert set(cur_registered.keys()) == set(prev_registered.keys())
newly_registered = {
k: sorted(set(cur_registered[k]) - set(prev_registered[k]))
for k in sorted(cur_registered.keys())
}
newly_registered = {k: v for k, v in newly_registered.items() if len(v) > 0}
result = CachedResult(
sha1=file_hash,
registered=newly_registered,
status=status.name,
)
return result, status
class _BootstrapBreakException(Exception): class _BootstrapBreakException(Exception):
pass pass
...@@ -223,7 +329,26 @@ def break_bootstrap(): ...@@ -223,7 +329,26 @@ def break_bootstrap():
return return
def bootstrap_registries(catch_exception=True): def _load_cached_results(filename: str) -> Dict[str, CachedResult]:
with open(filename) as f:
loaded = yaml.safe_load(f)
assert isinstance(loaded, dict), f"Wrong format: {filename}"
results = {
filename: CachedResult(**result_dic) for filename, result_dic in loaded.items()
}
return results
def _dump_cached_results(cached_results: Dict[str, CachedResult], filename: str):
    """Write the bootstrap results to `filename` as a YAML document.

    Every `CachedResult` is turned into a plain dict first so that
    `yaml.safe_dump` can serialize it. The whole document is rendered before
    the file is opened for writing.
    """
    serializable = {}
    for path, result in cached_results.items():
        serializable[path] = asdict(result)

    text = yaml.safe_dump(serializable)
    with open(filename, "w") as out:
        out.write(text)
def bootstrap_registries(enable_cache: bool = True, catch_exception: bool = True):
""" """
Bootstrap all registries so that all objects are effectively registered. Bootstrap all registries so that all objects are effectively registered.
...@@ -243,52 +368,57 @@ def bootstrap_registries(catch_exception=True): ...@@ -243,52 +368,57 @@ def bootstrap_registries(catch_exception=True):
start = time.perf_counter() start = time.perf_counter()
# load cached bootstrap results if exist
cached_bootstrap_results: Dict[str, CachedResult] = {}
if enable_cache:
filename = os.path.join(_get_cache_dir(), _BOOTSTRAP_CACHE_FILENAME)
if os.path.isfile(filename):
logger.info(f"Loading bootstrap cache at {filename} ...")
cached_bootstrap_results = _load_cached_results(filename)
else:
logger.info(
f"Can't find the bootstrap cache at {filename}, start from scratch"
)
# locate all the files under d2go package # locate all the files under d2go package
# NOTE: we may extend to support user-defined locations if necessary # NOTE: we may extend to support user-defined locations if necessary
d2go_root = pkg_resources.resource_filename("d2go", "") d2go_root = pkg_resources.resource_filename("d2go", "")
logger.info(f"Start bootstrapping for d2go_root: {d2go_root} ...") logger.info(f"Start bootstrapping for d2go_root: {d2go_root} ...")
all_files = glob.glob(f"{d2go_root}/**/*.py", recursive=True) all_files = glob.glob(f"{d2go_root}/**/*.py", recursive=True)
all_files = [os.path.relpath(x, os.path.dirname(d2go_root)) for x in all_files]
skip_files = [] new_bootstrap_results: Dict[str, CachedResult] = {}
exception_files = [] files_per_status = defaultdict(list)
time_per_file = {} time_per_file = {}
for filename in all_files: for filename in all_files:
_log( _log(1, f"bootstrap for file: {filename}")
1,
f"bootstrap for file under d2go_root: {os.path.relpath(filename, d2go_root)}",
)
cached_result = cached_bootstrap_results.get(filename, None)
with _catchtime() as t: with _catchtime() as t:
try: result, status = _bootstrap_file(filename, catch_exception, cached_result)
_bootstrap_file(filename) new_bootstrap_results[filename] = result
except _BootstrapBreakException: files_per_status[status].append(filename)
# the bootstrap process is manually skipped
skip_files.append(filename)
continue
except Exception as e:
if catch_exception:
_log(
1,
"Encountered the following error during bootstrap:"
+ "".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
)
exception_files.append(filename)
else:
raise e
time_per_file[filename] = t.time time_per_file[filename] = t.time
end = time.perf_counter() end = time.perf_counter()
duration = end - start duration = end - start
status_breakdown = ", ".join(
[f"{len(files_per_status[status])} {status.name}" for status in BootstrapStatus]
)
logger.info( logger.info(
f"Finished bootstrapping for {len(all_files)} files ({len(skip_files)} break-ed)" f"Finished bootstrapping for {len(all_files)} files ({status_breakdown})"
f" in {duration:.2f} seconds." f" in {duration:.2f} seconds."
) )
exception_files = [
filename
for filename, result in new_bootstrap_results.items()
if result.status == BootstrapStatus.FAILED.name
]
if len(exception_files) > 0: if len(exception_files) > 0:
logger.warning( logger.warning(
"Encountered error bootstrapping following {} files," "Found exception for the following {} files (either during this bootstrap"
" registration inside those files might not work!\n{}".format( " run or from previous cached result), registration inside those files"
" might not work!\n{}".format(
len(exception_files), len(exception_files),
"\n".join(exception_files), "\n".join(exception_files),
) )
...@@ -301,4 +431,9 @@ def bootstrap_registries(catch_exception=True): ...@@ -301,4 +431,9 @@ def bootstrap_registries(catch_exception=True):
for x in sorted(all_time, key=lambda x: x[1])[-TOP_N:]: for x in sorted(all_time, key=lambda x: x[1])[-TOP_N:]:
_log(2, x) _log(2, x)
if enable_cache:
filename = os.path.join(_get_cache_dir(), _BOOTSTRAP_CACHE_FILENAME)
logger.info(f"Writing updated bootstrap results to {filename} ...")
_dump_cached_results(new_bootstrap_results, filename)
_IS_BOOTSTRAPPED = True _IS_BOOTSTRAPPED = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment