Commit 1f45cf04 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

making bootstrap more robust

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/81

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/283

- add `MoreMagicMock`, which handles inheritance and comparison.
- also support lazy registering mocked objects (has to be `MoreMagicMock`).
- don't need to skip `skip files that doesn't contain ".register()" call` since we can handle most files pretty well now.
- also mock the open
- delay the import for `from detectron2.utils.testing import assert_instances_allclose`; for some reason python is doing magic things if you import anything starting with `assert`, so the mocked import doesn't work.
- makes log function nicer.

Reviewed By: tglik

Differential Revision: D36798327

fbshipit-source-id: ccda7e7583b95a24f3dde1bbe0468593dacb8663
parent e73947e1
......@@ -9,23 +9,34 @@ import logging
import os
import time
import traceback
from unittest import mock
import pkg_resources
from mobile_cv.common.misc.registry import LazyRegisterable, Registry
from mobile_cv.common.misc.py import MoreMagicMock
from mobile_cv.common.misc.registry import (
CLASS_OR_FUNCTION_TYPES,
LazyRegisterable,
Registry,
)
logger = logging.getLogger(__name__)
orig_import = builtins.__import__
orig_do_register = Registry._do_register
orig_open = builtins.open
orig__register = Registry._register
_INSIDE_BOOTSTRAP = False
_IS_BOOTSTRAPPED = False
_VERBOSE_LEVEL = 0
_BOOTSTRAP_PACKAGE = "d2go.registry._bootstrap"
def _log(lvl, msg):
_VERBOSE_LEVEL = 0
if _VERBOSE_LEVEL >= lvl:
print(msg)
class _catchtime:
def __enter__(self):
self.time = time.perf_counter()
......@@ -93,23 +104,45 @@ def _import_mock(name, globals=None, locals=None, fromlist=(), level=0):
return orig_import(name, globals, locals, fromlist=fromlist, level=level)
else:
# return a Mock instead of making a real import
if _VERBOSE_LEVEL >= 2:
print(f"mock import: {name}; fromlist={fromlist}; level={level}")
m = mock.MagicMock()
m.__version__ = mock.MagicMock()
_log(2, f"mock import: {name}; fromlist={fromlist}; level={level}")
m = MoreMagicMock()
return m
def _do_register_mock(self, name, obj):
assert isinstance(name, str), f"Can't register use non-string name: {name}"
def _open_mock(*args, **kwargs):
return MoreMagicMock()
def _register_mock(self, name, obj):
"""Convert `obj` to LazyRegisterable"""
# Instead of register the (possibly mocked) object which is created under the
# "fake" package _BOOTSTRAP_PACKAGE, register a lazy-object (i.e. a string) pointing
# to its original (possibly un-imported) module.
assert obj.__module__.startswith(_BOOTSTRAP_PACKAGE + ".")
orig_module = obj.__module__[len(_BOOTSTRAP_PACKAGE + ".") :]
registerable = LazyRegisterable(module=orig_module, name=obj.__qualname__)
return orig_do_register(self, name, registerable)
def _resolve_real_module(module_in_bootstrap_package):
assert module_in_bootstrap_package.startswith(_BOOTSTRAP_PACKAGE + ".")
orig_module = module_in_bootstrap_package[len(_BOOTSTRAP_PACKAGE + ".") :]
return orig_module
if isinstance(obj, MoreMagicMock):
assert obj.mocked_obj_info is not None, obj
if name is None:
name = obj.mocked_obj_info["__name__"]
obj = LazyRegisterable(
module=_resolve_real_module(obj.mocked_obj_info["__module__"]),
name=obj.mocked_obj_info["__qualname__"],
)
elif isinstance(obj, LazyRegisterable):
pass
else:
assert isinstance(obj, CLASS_OR_FUNCTION_TYPES), obj
if name is None:
name = obj.__name__
obj = LazyRegisterable(
module=_resolve_real_module(obj.__module__), name=obj.__qualname__
)
return orig__register(self, name, obj)
@contextlib.contextmanager
......@@ -117,14 +150,16 @@ def _bootstrap_patch():
global _INSIDE_BOOTSTRAP
builtins.__import__ = _import_mock
Registry._do_register = _do_register_mock
builtins.open = _open_mock
Registry._register = _register_mock
_INSIDE_BOOTSTRAP = True
try:
yield
finally:
builtins.__import__ = orig_import
Registry._do_register = orig_do_register
builtins.open = orig_open
Registry._register = orig__register
_INSIDE_BOOTSTRAP = False
......@@ -150,28 +185,23 @@ def _bootstrap_file(filename):
with _catchtime() as t:
with open(filename) as f:
content = f.read()
# skip files that doesn't contain ".register()" call, this would filter out many
# files and speed up the process
if ".register" not in content:
if _VERBOSE_LEVEL >= 2:
print("Skip file because there's no `.register()` call")
return
tree = ast.parse(content)
# remove all the class inheritance. eg. `class MyClass(nn.Module)` -> `class MyClass()`
# HACK: convert multiple inheritance to single inheritance, this is needed
# because current implementation of MoreMagicMock can't handle this well.
# eg. `class MyClass(MyMixin, nn.Module)` -> `class MyClass(MyMixin)`
for stmt in tree.body:
if isinstance(stmt, ast.ClassDef):
stmt.bases.clear()
if len(stmt.bases) > 1:
stmt.bases = stmt.bases[:1]
stmt.keywords.clear()
if _VERBOSE_LEVEL >= 2:
print(f"Parsing AST takes {t.time} sec")
_log(2, f"Parsing AST takes {t.time} sec")
with _catchtime() as t:
with _bootstrap_patch():
exec(compile(tree, filename, "exec"), exec_globals) # noqa
if _VERBOSE_LEVEL >= 2:
print(f"Execute file takes {t.time} sec")
_log(2, f"Execute file takes {t.time} sec")
class _BootstrapBreakException(Exception):
......@@ -208,8 +238,7 @@ def bootstrap_registries(catch_exception=True):
return
if _INSIDE_BOOTSTRAP:
if _VERBOSE_LEVEL >= 1:
print("calling bootstrap_registries() inside bootstrap process, skip ...")
_log(1, "calling bootstrap_registries() inside bootstrap process, skip ...")
return
start = time.perf_counter()
......@@ -224,9 +253,9 @@ def bootstrap_registries(catch_exception=True):
exception_files = []
time_per_file = {}
for filename in all_files:
if _VERBOSE_LEVEL >= 1:
print(
f"bootstrap for file under d2go_root: {os.path.relpath(filename, d2go_root)}"
_log(
1,
f"bootstrap for file under d2go_root: {os.path.relpath(filename, d2go_root)}",
)
with _catchtime() as t:
......@@ -238,10 +267,13 @@ def bootstrap_registries(catch_exception=True):
continue
except Exception as e:
if catch_exception:
if _VERBOSE_LEVEL >= 1:
print("Encountered the following error during bootstrap:")
traceback.print_exception(type(e), e, e.__traceback__)
print("")
_log(
1,
"Encountered the following error during bootstrap:"
+ "".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
)
exception_files.append(filename)
else:
raise e
......@@ -262,13 +294,11 @@ def bootstrap_registries(catch_exception=True):
)
)
if _VERBOSE_LEVEL >= 2:
# Log slowest Top-N files
TOP_N = 100
print(f"Top-{TOP_N} slowest files during bootstrap:")
all_time = [
(os.path.relpath(k, d2go_root), v) for k, v in time_per_file.items()
]
_log(2, f"Top-{TOP_N} slowest files during bootstrap:")
all_time = [(os.path.relpath(k, d2go_root), v) for k, v in time_per_file.items()]
for x in sorted(all_time, key=lambda x: x[1])[-TOP_N:]:
print(x)
_log(2, x)
_IS_BOOTSTRAPPED = True
......@@ -16,7 +16,6 @@ from d2go.utils.testing.data_loader_helper import (
create_detection_data_loader_on_toy_dataset,
)
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import assert_instances_allclose
from mobile_cv.predictor.api import create_predictor
from parameterized import parameterized
......@@ -354,6 +353,8 @@ class RCNNBaseTestCases:
with torch.no_grad():
pytorch_outputs = self.test_model(inputs)
from detectron2.utils.testing import assert_instances_allclose
assert_instances_allclose(
predictor_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment