Commit 1f45cf04 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

making bootstrap more robust

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/81

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/283

- add `MoreMagicMock`, which handles inheritance and comparison.
- also support lazy registering mocked objects (has to be `MoreMagicMock`).
- don't need to skip `skip files that doesn't contain ".register()" call` since we can handle most files pretty well now.
- also mock the open
- delay the import for `from detectron2.utils.testing import assert_instances_allclose`; for some reason python is doing magic things if you import anything starting with `assert`, so the mocked import doesn't work.
- makes log function nicer.

Reviewed By: tglik

Differential Revision: D36798327

fbshipit-source-id: ccda7e7583b95a24f3dde1bbe0468593dacb8663
parent e73947e1
...@@ -9,23 +9,34 @@ import logging ...@@ -9,23 +9,34 @@ import logging
import os import os
import time import time
import traceback import traceback
from unittest import mock
import pkg_resources import pkg_resources
from mobile_cv.common.misc.registry import LazyRegisterable, Registry from mobile_cv.common.misc.py import MoreMagicMock
from mobile_cv.common.misc.registry import (
CLASS_OR_FUNCTION_TYPES,
LazyRegisterable,
Registry,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
orig_import = builtins.__import__ orig_import = builtins.__import__
orig_do_register = Registry._do_register orig_open = builtins.open
orig__register = Registry._register
_INSIDE_BOOTSTRAP = False _INSIDE_BOOTSTRAP = False
_IS_BOOTSTRAPPED = False _IS_BOOTSTRAPPED = False
_VERBOSE_LEVEL = 0
_BOOTSTRAP_PACKAGE = "d2go.registry._bootstrap" _BOOTSTRAP_PACKAGE = "d2go.registry._bootstrap"
def _log(lvl, msg):
_VERBOSE_LEVEL = 0
if _VERBOSE_LEVEL >= lvl:
print(msg)
class _catchtime: class _catchtime:
def __enter__(self): def __enter__(self):
self.time = time.perf_counter() self.time = time.perf_counter()
...@@ -93,23 +104,45 @@ def _import_mock(name, globals=None, locals=None, fromlist=(), level=0): ...@@ -93,23 +104,45 @@ def _import_mock(name, globals=None, locals=None, fromlist=(), level=0):
return orig_import(name, globals, locals, fromlist=fromlist, level=level) return orig_import(name, globals, locals, fromlist=fromlist, level=level)
else: else:
# return a Mock instead of making a real import # return a Mock instead of making a real import
if _VERBOSE_LEVEL >= 2: _log(2, f"mock import: {name}; fromlist={fromlist}; level={level}")
print(f"mock import: {name}; fromlist={fromlist}; level={level}") m = MoreMagicMock()
m = mock.MagicMock()
m.__version__ = mock.MagicMock()
return m return m
def _do_register_mock(self, name, obj): def _open_mock(*args, **kwargs):
assert isinstance(name, str), f"Can't register use non-string name: {name}" return MoreMagicMock()
def _register_mock(self, name, obj):
"""Convert `obj` to LazyRegisterable"""
# Instead of register the (possibly mocked) object which is created under the # Instead of register the (possibly mocked) object which is created under the
# "fake" package _BOOTSTRAP_PACKAGE, register a lazy-object (i.e. a string) pointing # "fake" package _BOOTSTRAP_PACKAGE, register a lazy-object (i.e. a string) pointing
# to its original (possibly un-imported) module. # to its original (possibly un-imported) module.
assert obj.__module__.startswith(_BOOTSTRAP_PACKAGE + ".") def _resolve_real_module(module_in_bootstrap_package):
orig_module = obj.__module__[len(_BOOTSTRAP_PACKAGE + ".") :] assert module_in_bootstrap_package.startswith(_BOOTSTRAP_PACKAGE + ".")
registerable = LazyRegisterable(module=orig_module, name=obj.__qualname__) orig_module = module_in_bootstrap_package[len(_BOOTSTRAP_PACKAGE + ".") :]
return orig_do_register(self, name, registerable) return orig_module
if isinstance(obj, MoreMagicMock):
assert obj.mocked_obj_info is not None, obj
if name is None:
name = obj.mocked_obj_info["__name__"]
obj = LazyRegisterable(
module=_resolve_real_module(obj.mocked_obj_info["__module__"]),
name=obj.mocked_obj_info["__qualname__"],
)
elif isinstance(obj, LazyRegisterable):
pass
else:
assert isinstance(obj, CLASS_OR_FUNCTION_TYPES), obj
if name is None:
name = obj.__name__
obj = LazyRegisterable(
module=_resolve_real_module(obj.__module__), name=obj.__qualname__
)
return orig__register(self, name, obj)
@contextlib.contextmanager @contextlib.contextmanager
...@@ -117,14 +150,16 @@ def _bootstrap_patch(): ...@@ -117,14 +150,16 @@ def _bootstrap_patch():
global _INSIDE_BOOTSTRAP global _INSIDE_BOOTSTRAP
builtins.__import__ = _import_mock builtins.__import__ = _import_mock
Registry._do_register = _do_register_mock builtins.open = _open_mock
Registry._register = _register_mock
_INSIDE_BOOTSTRAP = True _INSIDE_BOOTSTRAP = True
try: try:
yield yield
finally: finally:
builtins.__import__ = orig_import builtins.__import__ = orig_import
Registry._do_register = orig_do_register builtins.open = orig_open
Registry._register = orig__register
_INSIDE_BOOTSTRAP = False _INSIDE_BOOTSTRAP = False
...@@ -150,28 +185,23 @@ def _bootstrap_file(filename): ...@@ -150,28 +185,23 @@ def _bootstrap_file(filename):
with _catchtime() as t: with _catchtime() as t:
with open(filename) as f: with open(filename) as f:
content = f.read() content = f.read()
# skip files that doesn't contain ".register()" call, this would filter out many
# files and speed up the process
if ".register" not in content:
if _VERBOSE_LEVEL >= 2:
print("Skip file because there's no `.register()` call")
return
tree = ast.parse(content) tree = ast.parse(content)
# remove all the class inheritance. eg. `class MyClass(nn.Module)` -> `class MyClass()`
# HACK: convert multiple inheritance to single inheritance, this is needed
# because current implementation of MoreMagicMock can't handle this well.
# eg. `class MyClass(MyMixin, nn.Module)` -> `class MyClass(MyMixin)`
for stmt in tree.body: for stmt in tree.body:
if isinstance(stmt, ast.ClassDef): if isinstance(stmt, ast.ClassDef):
stmt.bases.clear() if len(stmt.bases) > 1:
stmt.bases = stmt.bases[:1]
stmt.keywords.clear() stmt.keywords.clear()
if _VERBOSE_LEVEL >= 2:
print(f"Parsing AST takes {t.time} sec") _log(2, f"Parsing AST takes {t.time} sec")
with _catchtime() as t: with _catchtime() as t:
with _bootstrap_patch(): with _bootstrap_patch():
exec(compile(tree, filename, "exec"), exec_globals) # noqa exec(compile(tree, filename, "exec"), exec_globals) # noqa
if _VERBOSE_LEVEL >= 2: _log(2, f"Execute file takes {t.time} sec")
print(f"Execute file takes {t.time} sec")
class _BootstrapBreakException(Exception): class _BootstrapBreakException(Exception):
...@@ -208,8 +238,7 @@ def bootstrap_registries(catch_exception=True): ...@@ -208,8 +238,7 @@ def bootstrap_registries(catch_exception=True):
return return
if _INSIDE_BOOTSTRAP: if _INSIDE_BOOTSTRAP:
if _VERBOSE_LEVEL >= 1: _log(1, "calling bootstrap_registries() inside bootstrap process, skip ...")
print("calling bootstrap_registries() inside bootstrap process, skip ...")
return return
start = time.perf_counter() start = time.perf_counter()
...@@ -224,10 +253,10 @@ def bootstrap_registries(catch_exception=True): ...@@ -224,10 +253,10 @@ def bootstrap_registries(catch_exception=True):
exception_files = [] exception_files = []
time_per_file = {} time_per_file = {}
for filename in all_files: for filename in all_files:
if _VERBOSE_LEVEL >= 1: _log(
print( 1,
f"bootstrap for file under d2go_root: {os.path.relpath(filename, d2go_root)}" f"bootstrap for file under d2go_root: {os.path.relpath(filename, d2go_root)}",
) )
with _catchtime() as t: with _catchtime() as t:
try: try:
...@@ -238,10 +267,13 @@ def bootstrap_registries(catch_exception=True): ...@@ -238,10 +267,13 @@ def bootstrap_registries(catch_exception=True):
continue continue
except Exception as e: except Exception as e:
if catch_exception: if catch_exception:
if _VERBOSE_LEVEL >= 1: _log(
print("Encountered the following error during bootstrap:") 1,
traceback.print_exception(type(e), e, e.__traceback__) "Encountered the following error during bootstrap:"
print("") + "".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
)
exception_files.append(filename) exception_files.append(filename)
else: else:
raise e raise e
...@@ -262,13 +294,11 @@ def bootstrap_registries(catch_exception=True): ...@@ -262,13 +294,11 @@ def bootstrap_registries(catch_exception=True):
) )
) )
if _VERBOSE_LEVEL >= 2: # Log slowest Top-N files
TOP_N = 100 TOP_N = 100
print(f"Top-{TOP_N} slowest files during bootstrap:") _log(2, f"Top-{TOP_N} slowest files during bootstrap:")
all_time = [ all_time = [(os.path.relpath(k, d2go_root), v) for k, v in time_per_file.items()]
(os.path.relpath(k, d2go_root), v) for k, v in time_per_file.items() for x in sorted(all_time, key=lambda x: x[1])[-TOP_N:]:
] _log(2, x)
for x in sorted(all_time, key=lambda x: x[1])[-TOP_N:]:
print(x)
_IS_BOOTSTRAPPED = True _IS_BOOTSTRAPPED = True
...@@ -16,7 +16,6 @@ from d2go.utils.testing.data_loader_helper import ( ...@@ -16,7 +16,6 @@ from d2go.utils.testing.data_loader_helper import (
create_detection_data_loader_on_toy_dataset, create_detection_data_loader_on_toy_dataset,
) )
from detectron2.structures import Boxes, Instances from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import assert_instances_allclose
from mobile_cv.predictor.api import create_predictor from mobile_cv.predictor.api import create_predictor
from parameterized import parameterized from parameterized import parameterized
...@@ -354,6 +353,8 @@ class RCNNBaseTestCases: ...@@ -354,6 +353,8 @@ class RCNNBaseTestCases:
with torch.no_grad(): with torch.no_grad():
pytorch_outputs = self.test_model(inputs) pytorch_outputs = self.test_model(inputs)
from detectron2.utils.testing import assert_instances_allclose
assert_instances_allclose( assert_instances_allclose(
predictor_outputs[0]["instances"], predictor_outputs[0]["instances"],
pytorch_outputs[0]["instances"], pytorch_outputs[0]["instances"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment