Commit 82295dbf authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable black for mobile-vision

Summary:
https://fb.workplace.com/groups/pythonfoundation/posts/2990917737888352

Remove `mobile-vision` from opt-out list; leaving `mobile-vision/SNPE` opted out because of 3rd-party code.

arc lint --take BLACK --apply-patches --paths-cmd 'hg files mobile-vision'

allow-large-files

Reviewed By: sstsai-adl

Differential Revision: D30721093

fbshipit-source-id: 9e5c16d988b315b93a28038443ecfb92efd18ef8
parent a56c7e15
...@@ -4,12 +4,15 @@ ...@@ -4,12 +4,15 @@
import copy import copy
class FBNetV2ModelArch(object): class FBNetV2ModelArch(object):
_MODEL_ARCH = {} _MODEL_ARCH = {}
@staticmethod @staticmethod
def add(name, arch): def add(name, arch):
assert name not in FBNetV2ModelArch._MODEL_ARCH, \ assert (
"Arch name '{}' is already existed".format(name) name not in FBNetV2ModelArch._MODEL_ARCH
), "Arch name '{}' is already existed".format(name)
FBNetV2ModelArch._MODEL_ARCH[name] = arch FBNetV2ModelArch._MODEL_ARCH[name] = arch
@staticmethod @staticmethod
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import copy import copy
from mobile_cv.arch.fbnet_v2.modeldef_utils import _ex, e1, e2, e1p, e3, e4, e6
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
from mobile_cv.arch.fbnet_v2.modeldef_utils import _ex, e1, e2, e1p, e3, e4, e6
def _mutated_tuple(tp, pos, value): def _mutated_tuple(tp, pos, value):
...@@ -34,7 +35,6 @@ _BASIC_ARGS = { ...@@ -34,7 +35,6 @@ _BASIC_ARGS = {
# FBNetV1 builder. # FBNetV1 builder.
# "always_pw": True, # "always_pw": True,
# "bias": False, # "bias": False,
# temporarily disable zero_last_bn_gamma # temporarily disable zero_last_bn_gamma
"zero_last_bn_gamma": False, "zero_last_bn_gamma": False,
} }
...@@ -59,10 +59,7 @@ IRF_CFG = {"less_se_channels": False} ...@@ -59,10 +59,7 @@ IRF_CFG = {"less_se_channels": False}
FBNetV3_A_dsmask = [ FBNetV3_A_dsmask = [
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 1, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 1, {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 32, 1, 1, {"expansion": 2}, IRF_CFG), ("ir_k5", 32, 1, 1, {"expansion": 2}, IRF_CFG),
...@@ -85,10 +82,7 @@ FBNetV3_A_dsmask = [ ...@@ -85,10 +82,7 @@ FBNetV3_A_dsmask = [
] ]
FBNetV3_A_dsmask_tiny = [ FBNetV3_A_dsmask_tiny = [
[ [("conv_k3", 8, 2, 1), ("ir_k3", 8, 1, 1, {"expansion": 1}, IRF_CFG)],
("conv_k3", 8, 2, 1),
("ir_k3", 8, 1, 1, {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 16, 2, 1, {"expansion": 3}, IRF_CFG), ("ir_k5", 16, 2, 1, {"expansion": 3}, IRF_CFG),
("ir_k5", 16, 1, 1, {"expansion": 2}, IRF_CFG), ("ir_k5", 16, 1, 1, {"expansion": 2}, IRF_CFG),
...@@ -112,10 +106,7 @@ FBNetV3_A_dsmask_tiny = [ ...@@ -112,10 +106,7 @@ FBNetV3_A_dsmask_tiny = [
FBNetV3_A = [ FBNetV3_A = [
# FBNetV3 arch without hs # FBNetV3 arch without hs
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG), ("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
...@@ -138,10 +129,7 @@ FBNetV3_A = [ ...@@ -138,10 +129,7 @@ FBNetV3_A = [
] ]
FBNetV3_B = [ FBNetV3_B = [
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG), ("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
...@@ -303,10 +291,7 @@ FBNetV3_H = [ ...@@ -303,10 +291,7 @@ FBNetV3_H = [
FBNetV3_A_no_se = [ FBNetV3_A_no_se = [
# FBNetV3 without hs and SE (SE is not quantization friendly) # FBNetV3 without hs and SE (SE is not quantization friendly)
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG), ("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
...@@ -329,10 +314,7 @@ FBNetV3_A_no_se = [ ...@@ -329,10 +314,7 @@ FBNetV3_A_no_se = [
] ]
FBNetV3_B_no_se = [ FBNetV3_B_no_se = [
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG), ("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
...@@ -357,10 +339,7 @@ FBNetV3_B_no_se = [ ...@@ -357,10 +339,7 @@ FBNetV3_B_no_se = [
# FBNetV3_B model, a lighter version for real-time inference # FBNetV3_B model, a lighter version for real-time inference
FBNetV3_B_light_no_se = [ FBNetV3_B_light_no_se = [
[ [("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[ [
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG), ("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 2, {"expansion": 2}, IRF_CFG), ("ir_k5", 24, 1, 2, {"expansion": 2}, IRF_CFG),
...@@ -411,11 +390,21 @@ SMALL_UPSAMPLE_HEAD_STAGES = [ ...@@ -411,11 +390,21 @@ SMALL_UPSAMPLE_HEAD_STAGES = [
# NOTE: Compared with SMALL_UPSAMPLE_HEAD_STAGES, this does one more down-sample # NOTE: Compared with SMALL_UPSAMPLE_HEAD_STAGES, this does one more down-sample
# in the first "layer" and then up-sample twice # in the first "layer" and then up-sample twice
SMALL_DS_UPSAMPLE_HEAD_STAGES = [ SMALL_DS_UPSAMPLE_HEAD_STAGES = [
[("ir_k3", 128, 2, 1, e4), ("ir_k3", 128, 1, 2, e6), ("ir_k3", 128, -2, 1, e6), ("ir_k3", 64, -2, 1, e3)], # noqa [
("ir_k3", 128, 2, 1, e4),
("ir_k3", 128, 1, 2, e6),
("ir_k3", 128, -2, 1, e6),
("ir_k3", 64, -2, 1, e3),
], # noqa
] ]
TINY_DS_UPSAMPLE_HEAD_STAGES = [ TINY_DS_UPSAMPLE_HEAD_STAGES = [
[("ir_k3", 64, 2, 1, e4), ("ir_k3", 64, 1, 2, e4), ("ir_k3", 64, -2, 1, e4), ("ir_k3", 40, -2, 1, e3)], # noqa [
("ir_k3", 64, 2, 1, e4),
("ir_k3", 64, 1, 2, e4),
("ir_k3", 64, -2, 1, e4),
("ir_k3", 40, -2, 1, e3),
], # noqa
] ]
FPN_UPSAMPLE_HEAD_STAGES = [ FPN_UPSAMPLE_HEAD_STAGES = [
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
import logging
import numpy as np import numpy as np
import torch import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import ( from d2go.data.dataset_mappers import (
D2GO_DATA_MAPPER_REGISTRY, D2GO_DATA_MAPPER_REGISTRY,
D2GoDatasetMapper, D2GoDatasetMapper,
) )
from d2go.utils.helper import alias from d2go.utils.helper import alias
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from torch import nn
from torch.nn import functional as F
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SUBCLASS_FETCHER_REGISTRY = Registry("SUBCLASS_FETCHER") SUBCLASS_FETCHER_REGISTRY = Registry("SUBCLASS_FETCHER")
def add_subclass_configs(cfg): def add_subclass_configs(cfg):
_C = cfg _C = cfg
_C.MODEL.SUBCLASS = CN() _C.MODEL.SUBCLASS = CN()
...@@ -32,42 +32,44 @@ def add_subclass_configs(cfg): ...@@ -32,42 +32,44 @@ def add_subclass_configs(cfg):
_C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set _C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set
_C.MODEL.SUBCLASS.NUM_LAYERS = 1 _C.MODEL.SUBCLASS.NUM_LAYERS = 1
_C.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER = "SubclassFetcher" # ABC, must be set _C.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER = "SubclassFetcher" # ABC, must be set
_C.MODEL.SUBCLASS.SUBCLASS_MAPPING = [] # subclass mapping from model output to annotation _C.MODEL.SUBCLASS.SUBCLASS_MAPPING = (
[]
) # subclass mapping from model output to annotation
class SubclassFetcher(ABC): class SubclassFetcher(ABC):
""" Fetcher class to read subclass id annotations from dataset and prepare for train/eval. """Fetcher class to read subclass id annotations from dataset and prepare for train/eval.
Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator
to use with custom projects. to use with custom projects.
""" """
@property @property
@abstractmethod @abstractmethod
def subclass_names(self) -> List[str]: def subclass_names(self) -> List[str]:
""" Overwrite this member with any new mappings' subclass names, which """Overwrite this member with any new mappings' subclass names, which
may be useful for specific evaluation purposes. may be useful for specific evaluation purposes.
len(self.subclass_names) should be equal to the expected number len(self.subclass_names) should be equal to the expected number
of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1). of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1).
""" """
pass pass
def remap(self, subclass_id: int) -> int: def remap(self, subclass_id: int) -> int:
""" Map subclass ids read from dataset to new label id """ """Map subclass ids read from dataset to new label id"""
return subclass_id return subclass_id
def fetch_subclass_ids(self, dataset_dict: Dict[str, Any]) -> List[int]: def fetch_subclass_ids(self, dataset_dict: Dict[str, Any]) -> List[int]:
""" Get all the subclass_ids in a dataset dict """ """Get all the subclass_ids in a dataset dict"""
extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]] extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]]
subclass_ids = [extras["subclass_id"] for extras in extras_list] subclass_ids = [extras["subclass_id"] for extras in extras_list]
return subclass_ids return subclass_ids
@D2GO_DATA_MAPPER_REGISTRY.register() @D2GO_DATA_MAPPER_REGISTRY.register()
class SubclassDatasetMapper(D2GoDatasetMapper): class SubclassDatasetMapper(D2GoDatasetMapper):
""" """
Wrap any dataset mapper, encode gt_subclasses to the instances. Wrap any dataset mapper, encode gt_subclasses to the instances.
""" """
def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None): def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None):
super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens) super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens)
if subclass_fetcher is None: if subclass_fetcher is None:
...@@ -93,19 +95,23 @@ class SubclassDatasetMapper(D2GoDatasetMapper): ...@@ -93,19 +95,23 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
# Transform removes key 'annotations' from the dataset dict # Transform removes key 'annotations' from the dataset dict
mapped_dataset_dict = super()._original_call(dataset_dict) mapped_dataset_dict = super()._original_call(dataset_dict)
if (self.is_train and self.subclass_on): if self.is_train and self.subclass_on:
subclass_ids = self.subclass_fetcher.fetch_subclass_ids(dataset_dict) subclass_ids = self.subclass_fetcher.fetch_subclass_ids(dataset_dict)
subclasses = torch.tensor(subclass_ids, dtype=torch.int64) subclasses = torch.tensor(subclass_ids, dtype=torch.int64)
mapped_dataset_dict["instances"].gt_subclasses = subclasses mapped_dataset_dict["instances"].gt_subclasses = subclasses
return mapped_dataset_dict return mapped_dataset_dict
def build_subclass_head(cfg, in_chann, out_chann): def build_subclass_head(cfg, in_chann, out_chann):
# fully connected layers: n-1 in_chann x in_chann layers, and 1 in_chann x out_chann layer # fully connected layers: n-1 in_chann x in_chann layers, and 1 in_chann x out_chann layer
layers = [nn.Linear(in_chann, in_chann) for _ in range(cfg.MODEL.SUBCLASS.NUM_LAYERS - 1)] layers = [
nn.Linear(in_chann, in_chann) for _ in range(cfg.MODEL.SUBCLASS.NUM_LAYERS - 1)
]
layers.append(nn.Linear(in_chann, out_chann)) layers.append(nn.Linear(in_chann, out_chann))
return nn.Sequential(*layers) return nn.Sequential(*layers)
@ROI_HEADS_REGISTRY.register() @ROI_HEADS_REGISTRY.register()
class StandardROIHeadsWithSubClass(StandardROIHeads): class StandardROIHeadsWithSubClass(StandardROIHeads):
""" """
...@@ -119,7 +125,9 @@ class StandardROIHeadsWithSubClass(StandardROIHeads): ...@@ -119,7 +125,9 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
return return
self.num_subclasses = cfg.MODEL.SUBCLASS.NUM_SUBCLASSES self.num_subclasses = cfg.MODEL.SUBCLASS.NUM_SUBCLASSES
self.subclass_head = build_subclass_head(cfg, self.box_head.output_shape.channels, self.num_subclasses + 1) self.subclass_head = build_subclass_head(
cfg, self.box_head.output_shape.channels, self.num_subclasses + 1
)
for layer in self.subclass_head: for layer in self.subclass_head:
nn.init.normal_(layer.weight, std=0.01) nn.init.normal_(layer.weight, std=0.01)
...@@ -142,12 +150,16 @@ class StandardROIHeadsWithSubClass(StandardROIHeads): ...@@ -142,12 +150,16 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
for pp_per_im in proposals: for pp_per_im in proposals:
if not pp_per_im.has("gt_subclasses"): if not pp_per_im.has("gt_subclasses"):
background_subcls_idx = 0 background_subcls_idx = 0
pp_per_im.gt_subclasses = torch.cuda.LongTensor(len(pp_per_im)).fill_(background_subcls_idx) pp_per_im.gt_subclasses = torch.cuda.LongTensor(
len(pp_per_im)
).fill_(background_subcls_idx)
del targets del targets
features_list = [features[f] for f in self.in_features] features_list = [features[f] for f in self.in_features]
box_features = self.box_pooler(features_list, [x.proposal_boxes for x in proposals]) box_features = self.box_pooler(
features_list, [x.proposal_boxes for x in proposals]
)
box_features = self.box_head(box_features) box_features = self.box_head(box_features)
predictions = self.box_predictor(box_features) predictions = self.box_predictor(box_features)
# --- end copy --------------------------------------------------------- # --- end copy ---------------------------------------------------------
...@@ -155,8 +167,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads): ...@@ -155,8 +167,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
# NOTE: don't delete box_features, keep it temporarily # NOTE: don't delete box_features, keep it temporarily
# del box_features # del box_features
box_features = box_features.view( box_features = box_features.view(
box_features.shape[0], box_features.shape[0], np.prod(box_features.shape[1:])
np.prod(box_features.shape[1:])
) )
pred_subclass_logits = self.subclass_head(box_features) pred_subclass_logits = self.subclass_head(box_features)
...@@ -195,8 +206,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads): ...@@ -195,8 +206,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
if torch.onnx.is_in_onnx_export(): if torch.onnx.is_in_onnx_export():
assert len(pred_instances) == 1 assert len(pred_instances) == 1
pred_instances[0].pred_subclass_prob = alias( pred_instances[0].pred_subclass_prob = alias(
pred_instances[0].pred_subclass_prob, pred_instances[0].pred_subclass_prob, "subclass_prob_nms"
"subclass_prob_nms"
) )
return pred_instances, {} return pred_instances, {}
...@@ -2,4 +2,4 @@ ...@@ -2,4 +2,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .build import build_optimizer_mapper from .build import build_optimizer_mapper
__all__ = ['build_optimizer_mapper'] __all__ = ["build_optimizer_mapper"]
...@@ -12,9 +12,7 @@ from detectron2.utils.registry import Registry ...@@ -12,9 +12,7 @@ from detectron2.utils.registry import Registry
D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER") D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")
def reduce_param_groups( def reduce_param_groups(param_groups: List[Dict[str, Any]]):
param_groups: List[Dict[str, Any]]
):
# The number of parameter groups needs to be as small as possible in order # The number of parameter groups needs to be as small as possible in order
# to efficiently use the PyTorch multi-tensor optimizer. Therefore instead # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
# of using a parameter_group per single parameter, we group all the params # of using a parameter_group per single parameter, we group all the params
......
...@@ -14,7 +14,7 @@ def create_runner( ...@@ -14,7 +14,7 @@ def create_runner(
class_full_name: str, *args, **kwargs class_full_name: str, *args, **kwargs
) -> Union[BaseRunner, Type[LightningModule]]: ) -> Union[BaseRunner, Type[LightningModule]]:
"""Constructs a runner instance if class is a d2go runner. Returns class """Constructs a runner instance if class is a d2go runner. Returns class
type if class is a Lightning module. type if class is a Lightning module.
""" """
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1) runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name) runner_module = importlib.import_module(runner_module_name)
......
...@@ -29,13 +29,13 @@ PREPARED = "_prepared" ...@@ -29,13 +29,13 @@ PREPARED = "_prepared"
def rsetattr(obj: Any, attr: str, val: Any) -> None: def rsetattr(obj: Any, attr: str, val: Any) -> None:
""" Same as setattr but supports deeply nested objects. """ """Same as setattr but supports deeply nested objects."""
pre, _, post = attr.rpartition(".") pre, _, post = attr.rpartition(".")
return setattr(rgetattr(obj, pre) if pre else obj, post, val) return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj: Any, attr: str, *args) -> Any: def rgetattr(obj: Any, attr: str, *args) -> Any:
""" Same as getattr but supports deeply nested objects. """ """Same as getattr but supports deeply nested objects."""
def _getattr(obj, attr): def _getattr(obj, attr):
return getattr(obj, attr, *args) return getattr(obj, attr, *args)
...@@ -44,7 +44,7 @@ def rgetattr(obj: Any, attr: str, *args) -> Any: ...@@ -44,7 +44,7 @@ def rgetattr(obj: Any, attr: str, *args) -> Any:
def rhasattr(obj: Any, attr: str, *args) -> bool: def rhasattr(obj: Any, attr: str, *args) -> bool:
""" Same as hasattr but supports deeply nested objects. """ """Same as hasattr but supports deeply nested objects."""
try: try:
_ = rgetattr(obj, attr, *args) _ = rgetattr(obj, attr, *args)
...@@ -66,7 +66,7 @@ def _deepcopy(pl_module: LightningModule) -> LightningModule: ...@@ -66,7 +66,7 @@ def _deepcopy(pl_module: LightningModule) -> LightningModule:
def _quantized_forward(self, *args, **kwargs): def _quantized_forward(self, *args, **kwargs):
""" Forward method for a quantized module. """ """Forward method for a quantized module."""
if not self.training and hasattr(self, "_quantized"): if not self.training and hasattr(self, "_quantized"):
return self._quantized(*args, **kwargs) return self._quantized(*args, **kwargs)
return self._prepared(*args, **kwargs) return self._prepared(*args, **kwargs)
...@@ -99,11 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool: ...@@ -99,11 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]): def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED): if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
# model has been prepared for QAT before saving into checkpoint # model has been prepared for QAT before saving into checkpoint
setattr( setattr(model, PREPARED, _deepcopy(model).prepare_for_quant())
model,
PREPARED,
_deepcopy(model).prepare_for_quant()
)
class QuantizationMixin(ABC): class QuantizationMixin(ABC):
...@@ -241,7 +237,7 @@ class ModelTransform: ...@@ -241,7 +237,7 @@ class ModelTransform:
interval: Optional[int] = None interval: Optional[int] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
""" Validate a few properties for early failure. """ """Validate a few properties for early failure."""
if (self.step is None and self.interval is None) or ( if (self.step is None and self.interval is None) or (
self.step is not None and self.interval is not None self.step is not None and self.interval is not None
): ):
...@@ -469,7 +465,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin): ...@@ -469,7 +465,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
batch_idx: int, batch_idx: int,
dataloader_idx: int, dataloader_idx: int,
) -> None: ) -> None:
""" Applies model transforms at as specified during training. """ """Applies model transforms at as specified during training."""
apply_only_once = [] apply_only_once = []
current_step = trainer.global_step current_step = trainer.global_step
for i, transform in enumerate(self.transforms): for i, transform in enumerate(self.transforms):
...@@ -492,7 +488,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin): ...@@ -492,7 +488,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
] ]
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Quantize the weights since training has finalized. """ """Quantize the weights since training has finalized."""
if hasattr(pl_module, "_quantized") or self.skip_conversion: if hasattr(pl_module, "_quantized") or self.skip_conversion:
return return
pl_module._quantized = self.convert( pl_module._quantized = self.convert(
...@@ -563,7 +559,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin): ...@@ -563,7 +559,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
qconfig_dicts: Optional[QConfigDicts] = None, qconfig_dicts: Optional[QConfigDicts] = None,
preserved_attrs: Optional[List[str]] = None, preserved_attrs: Optional[List[str]] = None,
) -> None: ) -> None:
""" Initialize the callback. """ """Initialize the callback."""
self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}} self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}}
self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs) self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
self.prepared: Optional[torch.nn.Module] = None self.prepared: Optional[torch.nn.Module] = None
...@@ -593,7 +589,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin): ...@@ -593,7 +589,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
) )
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None: def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Convert the calibrated model to its finalized quantized version. """ """Convert the calibrated model to its finalized quantized version."""
self.quantized = self.convert( self.quantized = self.convert(
self.prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs self.prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
) )
...@@ -607,7 +603,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin): ...@@ -607,7 +603,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
batch_idx: int, batch_idx: int,
dataloader_idx: int, dataloader_idx: int,
) -> None: ) -> None:
""" Also run the validation batch through the quantized model for calibration. """ """Also run the validation batch through the quantized model for calibration."""
if self.should_calibrate: if self.should_calibrate:
with torch.no_grad(): with torch.no_grad():
self.prepared(batch) self.prepared(batch)
...@@ -395,7 +395,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -395,7 +395,7 @@ class Detectron2GoRunner(BaseRunner):
return results return results
def do_test(self, cfg, model, train_iter=None): def do_test(self, cfg, model, train_iter=None):
""" do_test does not load the weights of the model. """do_test does not load the weights of the model.
If you want to use it outside the regular training routine, If you want to use it outside the regular training routine,
you will have to load the weights through a checkpointer. you will have to load the weights through a checkpointer.
""" """
......
...@@ -15,10 +15,10 @@ from d2go.data.utils import ( ...@@ -15,10 +15,10 @@ from d2go.data.utils import (
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
) )
from d2go.export.d2_meta_arch import patch_d2_meta_arch from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import build_model
from d2go.modeling.model_freezing_utils import ( from d2go.modeling.model_freezing_utils import (
set_requires_grad, set_requires_grad,
) )
from d2go.modeling import build_model
from d2go.modeling.quantization import ( from d2go.modeling.quantization import (
default_prepare_for_quant, default_prepare_for_quant,
default_prepare_for_quant_convert, default_prepare_for_quant_convert,
...@@ -206,7 +206,9 @@ class DefaultTask(pl.LightningModule): ...@@ -206,7 +206,9 @@ class DefaultTask(pl.LightningModule):
flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res) flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res)
if self.trainer.global_rank: if self.trainer.global_rank:
assert len(flattened) == 0, "evaluation results should have been reduced on rank 0." assert (
len(flattened) == 0
), "evaluation results should have been reduced on rank 0."
self.log_dict(flattened, rank_zero_only=True) self.log_dict(flattened, rank_zero_only=True)
def test_epoch_end(self, _outputs) -> None: def test_epoch_end(self, _outputs) -> None:
......
...@@ -17,12 +17,12 @@ from d2go.config import ( ...@@ -17,12 +17,12 @@ from d2go.config import (
) )
from d2go.distributed import get_local_rank, get_num_processes_per_machine from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import GeneralizedRCNNRunner, create_runner from d2go.runner import GeneralizedRCNNRunner, create_runner
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger from detectron2.utils.logger import setup_logger
from detectron2.utils.serialize import PicklableWrapper from detectron2.utils.serialize import PicklableWrapper
from d2go.utils.helper import run_once
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
...@@ -34,7 +34,7 @@ def basic_argument_parser( ...@@ -34,7 +34,7 @@ def basic_argument_parser(
requires_config_file=True, requires_config_file=True,
requires_output_dir=True, requires_output_dir=True,
): ):
""" Basic cli tool parser for Detectron2Go binaries """ """Basic cli tool parser for Detectron2Go binaries"""
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument( parser.add_argument(
"--runner", "--runner",
...@@ -201,6 +201,7 @@ def _setup_after_launch(cfg: CN, output_dir: str, runner): ...@@ -201,6 +201,7 @@ def _setup_after_launch(cfg: CN, output_dir: str, runner):
cfg.OUTPUT_DIR = output_dir cfg.OUTPUT_DIR = output_dir
dump_cfg(cfg, os.path.join(output_dir, "config.yaml")) dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
def setup_after_launch(cfg: CN, output_dir: str, runner): def setup_after_launch(cfg: CN, output_dir: str, runner):
_setup_after_launch(cfg, output_dir, runner) _setup_after_launch(cfg, output_dir, runner)
logger.info("Initializing runner ...") logger.info("Initializing runner ...")
...@@ -210,10 +211,12 @@ def setup_after_launch(cfg: CN, output_dir: str, runner): ...@@ -210,10 +211,12 @@ def setup_after_launch(cfg: CN, output_dir: str, runner):
auto_scale_world_size(cfg, new_world_size=comm.get_world_size()) auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
def setup_after_lightning_launch(cfg: CN, output_dir: str): def setup_after_lightning_launch(cfg: CN, output_dir: str):
_setup_after_launch(cfg, output_dir, runner=None) _setup_after_launch(cfg, output_dir, runner=None)
log_info(cfg, runner=None) log_info(cfg, runner=None)
@run_once() @run_once()
def setup_loggers(output_dir, color=None): def setup_loggers(output_dir, color=None):
if not color: if not color:
......
...@@ -31,8 +31,8 @@ def get_rel_loss_checker(rel_thres=1.0): ...@@ -31,8 +31,8 @@ def get_rel_loss_checker(rel_thres=1.0):
class TrainImageWriter(object): class TrainImageWriter(object):
def __init__(self, cfg, tbx_writer, max_count=5): def __init__(self, cfg, tbx_writer, max_count=5):
""" max_count: max number of data written to tensorboard, additional call """max_count: max number of data written to tensorboard, additional call
will be ignored will be ignored
""" """
self.visualizer = VisualizerWrapper(cfg) self.visualizer = VisualizerWrapper(cfg)
self.writer = tbx_writer self.writer = tbx_writer
...@@ -58,8 +58,8 @@ class TrainImageWriter(object): ...@@ -58,8 +58,8 @@ class TrainImageWriter(object):
class FileWriter(object): class FileWriter(object):
def __init__(self, output_dir, max_count=5): def __init__(self, output_dir, max_count=5):
""" max_count: max number of data written to tensorboard, additional call """max_count: max number of data written to tensorboard, additional call
will be ignored will be ignored
""" """
self.output_dir = output_dir self.output_dir = output_dir
self.max_count = max_count self.max_count = max_count
......
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
from collections import deque from collections import deque
import cv2 import cv2
import detectron2.data.transforms as T
import torch import torch
from d2go.model_zoo import model_zoo
from detectron2.data import MetadataCatalog from detectron2.data import MetadataCatalog
from detectron2.utils.video_visualizer import VideoVisualizer from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer from detectron2.utils.visualizer import ColorMode, Visualizer
import detectron2.data.transforms as T
from d2go.model_zoo import model_zoo
class DemoPredictor: class DemoPredictor:
def __init__(self, model, min_size_test=224, max_size_test=320, input_format="RGB"): def __init__(self, model, min_size_test=224, max_size_test=320, input_format="RGB"):
self.model = model self.model = model
self.model.eval() self.model.eval()
self.aug = T.ResizeShortestEdge( self.aug = T.ResizeShortestEdge([min_size_test, min_size_test], max_size_test)
[min_size_test, min_size_test], max_size_test
)
self.input_format = input_format self.input_format = input_format
...@@ -43,6 +41,7 @@ class DemoPredictor: ...@@ -43,6 +41,7 @@ class DemoPredictor:
predictions = self.model([inputs])[0] predictions = self.model([inputs])[0]
return predictions return predictions
class VisualizationDemo(object): class VisualizationDemo(object):
def __init__(self, cfg, config_file, instance_mode=ColorMode.IMAGE, parallel=False): def __init__(self, cfg, config_file, instance_mode=ColorMode.IMAGE, parallel=False):
""" """
...@@ -59,7 +58,7 @@ class VisualizationDemo(object): ...@@ -59,7 +58,7 @@ class VisualizationDemo(object):
self.instance_mode = instance_mode self.instance_mode = instance_mode
self.parallel = parallel self.parallel = parallel
model = model_zoo.get(config_file, trained=True)#runner.build_model(cfg) model = model_zoo.get(config_file, trained=True) # runner.build_model(cfg)
self.predictor = DemoPredictor(model) self.predictor = DemoPredictor(model)
def run_on_image(self, image): def run_on_image(self, image):
...@@ -123,7 +122,9 @@ class VisualizationDemo(object): ...@@ -123,7 +122,9 @@ class VisualizationDemo(object):
) )
elif "instances" in predictions: elif "instances" in predictions:
predictions = predictions["instances"].to(self.cpu_device) predictions = predictions["instances"].to(self.cpu_device)
vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) vis_frame = video_visualizer.draw_instance_predictions(
frame, predictions
)
elif "sem_seg" in predictions: elif "sem_seg" in predictions:
vis_frame = video_visualizer.draw_sem_seg( vis_frame = video_visualizer.draw_sem_seg(
frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
......
...@@ -31,13 +31,13 @@ class EMAState(object): ...@@ -31,13 +31,13 @@ class EMAState(object):
decay: float = 0.999, decay: float = 0.999,
device: Optional[str] = None, device: Optional[str] = None,
) -> "EMAState": ) -> "EMAState":
""" Constructs model state from the model and move to device if given.""" """Constructs model state from the model and move to device if given."""
ret = cls(decay, device) ret = cls(decay, device)
ret.load_from(model) ret.load_from(model)
return ret return ret
def load_from(self, model: nn.Module) -> None: def load_from(self, model: nn.Module) -> None:
""" Load state from the model. """ """Load state from the model."""
self.state.clear() self.state.clear()
for name, val in self._get_model_state_iterator(model): for name, val in self._get_model_state_iterator(model):
val = val.detach().clone() val = val.detach().clone()
...@@ -47,7 +47,7 @@ class EMAState(object): ...@@ -47,7 +47,7 @@ class EMAState(object):
return len(self.state) > 0 return len(self.state) > 0
def apply_to(self, model: nn.Module) -> None: def apply_to(self, model: nn.Module) -> None:
""" Apply EMA state to the model. """ """Apply EMA state to the model."""
with torch.no_grad(): with torch.no_grad():
for name, val in self._get_model_state_iterator(model): for name, val in self._get_model_state_iterator(model):
assert ( assert (
...@@ -64,7 +64,7 @@ class EMAState(object): ...@@ -64,7 +64,7 @@ class EMAState(object):
self.state[name] = val.to(self.device) if self.device else val self.state[name] = val.to(self.device) if self.device else val
def to(self, device: torch.device) -> None: def to(self, device: torch.device) -> None:
""" moves EMA state to device. """ """moves EMA state to device."""
for name, val in self.state.items(): for name, val in self.state.items():
self.state[name] = val.to(device) self.state[name] = val.to(device)
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy import copy
import torch
import os
import logging import logging
import os
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager
from detectron2.utils.analysis import FlopCountAnalysis
from fvcore.nn import flop_count_table, flop_count_str
import mobile_cv.lut.lib.pt.flops_utils as flops_utils import mobile_cv.lut.lib.pt.flops_utils as flops_utils
import torch
from d2go.utils.helper import run_once from d2go.utils.helper import run_once
from detectron2.utils.analysis import FlopCountAnalysis
from detectron2.utils.file_io import PathManager
from fvcore.nn import flop_count_table, flop_count_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,7 +58,6 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True): ...@@ -58,7 +58,6 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
except Exception: except Exception:
logger.exception("Failed to estimate flops using mobile_cv's FlopsEstimation") logger.exception("Failed to estimate flops using mobile_cv's FlopsEstimation")
# 2. using d2/fvcore's flop counter # 2. using d2/fvcore's flop counter
try: try:
flops = FlopCountAnalysis(model, inputs) flops = FlopCountAnalysis(model, inputs)
...@@ -81,14 +80,16 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True): ...@@ -81,14 +80,16 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
flops_table = flop_count_table(flops, max_depth=3) flops_table = flop_count_table(flops, max_depth=3)
logger.info("Flops table:\n" + flops_table) logger.info("Flops table:\n" + flops_table)
except Exception: except Exception:
logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis") logger.exception(
"Failed to estimate flops using detectron2's FlopCountAnalysis"
)
return flops return flops
def add_flop_printing_hook( def add_flop_printing_hook(
model, model,
output_dir: str, output_dir: str,
): ):
""" """
Add a pytorch module forward hook that will print/save flops of the whole model Add a pytorch module forward hook that will print/save flops of the whole model
at the first time the model is called. at the first time the model is called.
...@@ -96,6 +97,7 @@ def add_flop_printing_hook( ...@@ -96,6 +97,7 @@ def add_flop_printing_hook(
Args: Args:
output_dir: directory to save more detailed flop info output_dir: directory to save more detailed flop info
""" """
def hook(module, input): def hook(module, input):
handle.remove() handle.remove()
dump_flops_info(module, input, output_dir) dump_flops_info(module, input, output_dir)
......
...@@ -6,11 +6,11 @@ from d2go.data.build import ( ...@@ -6,11 +6,11 @@ from d2go.data.build import (
add_random_subset_training_sampler_default_configs, add_random_subset_training_sampler_default_configs,
) )
from d2go.data.config import add_d2go_data_default_configs from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.backbone.fbnet_cfg import ( from d2go.modeling.backbone.fbnet_cfg import (
add_bifpn_default_configs, add_bifpn_default_configs,
add_fbnet_v2_default_configs, add_fbnet_v2_default_configs,
) )
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import add_model_freezing_configs from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.quantization import add_quantization_default_configs from d2go.modeling.quantization import add_quantization_default_configs
from d2go.modeling.subclass import add_subclass_configs from d2go.modeling.subclass import add_subclass_configs
......
...@@ -7,24 +7,21 @@ import inspect ...@@ -7,24 +7,21 @@ import inspect
import logging import logging
import math import math
import os import os
import re
import tempfile
import zipfile
import pickle import pickle
import re
import signal import signal
import sys import sys
import tempfile
import threading import threading
import time import time
import traceback import traceback
import typing import typing
import warnings import warnings
import pkg_resources import zipfile
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from random import random
import six
from functools import wraps from functools import wraps
from random import random
from typing import ( from typing import (
Any, Any,
Callable, Callable,
...@@ -39,12 +36,20 @@ from typing import ( ...@@ -39,12 +36,20 @@ from typing import (
Union, Union,
) )
import torch
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import pkg_resources
import six
import torch
from detectron2.checkpoint import DetectionCheckpointer from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
hooks,
launch,
)
from detectron2.evaluation import ( from detectron2.evaluation import (
CityscapesInstanceEvaluator, CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator, CityscapesSemSegEvaluator,
...@@ -66,9 +71,11 @@ NT = TypeVar("T", bound=NamedTuple) ...@@ -66,9 +71,11 @@ NT = TypeVar("T", bound=NamedTuple)
from detectron2.utils.events import TensorboardXWriter from detectron2.utils.events import TensorboardXWriter
class MultipleFunctionCallError(Exception): class MultipleFunctionCallError(Exception):
pass pass
def run_once( def run_once(
raise_on_multiple: bool = False, raise_on_multiple: bool = False,
# pyre-fixme[34]: `Variable[T]` isn't present in the function's parameters. # pyre-fixme[34]: `Variable[T]` isn't present in the function's parameters.
...@@ -102,8 +109,8 @@ def run_once( ...@@ -102,8 +109,8 @@ def run_once(
class retryable(object): class retryable(object):
"""Fake retryable function """Fake retryable function"""
"""
def __init__(self, num_tries=1, sleep_time=0.1): def __init__(self, num_tries=1, sleep_time=0.1):
pass pass
...@@ -134,6 +141,7 @@ def alias(x, name, is_backward=False): ...@@ -134,6 +141,7 @@ def alias(x, name, is_backward=False):
assert isinstance(x, torch.Tensor) assert isinstance(x, torch.Tensor)
return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward) return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
class D2Trainer(DefaultTrainer): class D2Trainer(DefaultTrainer):
@classmethod @classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None): def build_evaluator(cls, cfg, dataset_name, output_folder=None):
...@@ -183,6 +191,7 @@ class D2Trainer(DefaultTrainer): ...@@ -183,6 +191,7 @@ class D2Trainer(DefaultTrainer):
return evaluator_list[0] return evaluator_list[0]
return DatasetEvaluators(evaluator_list) return DatasetEvaluators(evaluator_list)
def reroute_config_path(path: str) -> str: def reroute_config_path(path: str) -> str:
""" """
Supporting rerouting the config files for convenience: Supporting rerouting the config files for convenience:
......
...@@ -6,5 +6,6 @@ import logging ...@@ -6,5 +6,6 @@ import logging
import os import os
from functools import lru_cache from functools import lru_cache
def get_tensorboard_log_dir(output_dir): def get_tensorboard_log_dir(output_dir):
return output_dir return output_dir
...@@ -67,7 +67,7 @@ def enable_ddp_env(func): ...@@ -67,7 +67,7 @@ def enable_ddp_env(func):
def tempdir(func): def tempdir(func):
""" A decorator for creating a tempory directory that is cleaned up after function execution. """ """A decorator for creating a tempory directory that is cleaned up after function execution."""
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
......
...@@ -5,14 +5,13 @@ import glob ...@@ -5,14 +5,13 @@ import glob
import multiprocessing as mp import multiprocessing as mp
import os import os
import time import time
import cv2 import cv2
import tqdm import tqdm
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from d2go.model_zoo import model_zoo from d2go.model_zoo import model_zoo
from d2go.utils.demo_predictor import VisualizationDemo from d2go.utils.demo_predictor import VisualizationDemo
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
# constants # constants
WINDOW_NAME = "COCO detections" WINDOW_NAME = "COCO detections"
...@@ -22,7 +21,9 @@ def setup_cfg(cfg, args): ...@@ -22,7 +21,9 @@ def setup_cfg(cfg, args):
# Set score_threshold for builtin models # Set score_threshold for builtin models
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
args.confidence_threshold
)
cfg.freeze() cfg.freeze()
return cfg return cfg
...@@ -31,11 +32,13 @@ def get_parser(): ...@@ -31,11 +32,13 @@ def get_parser():
parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
parser.add_argument( parser.add_argument(
"--config-file", "--config-file",
default='keypoint_rcnn_fbnetv3a_dsmask_C4.yaml', default="keypoint_rcnn_fbnetv3a_dsmask_C4.yaml",
metavar="FILE", metavar="FILE",
help="path to config file", help="path to config file",
) )
parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") parser.add_argument(
"--webcam", action="store_true", help="Take inputs from webcam."
)
parser.add_argument("--video-input", help="Path to video file.") parser.add_argument("--video-input", help="Path to video file.")
parser.add_argument( parser.add_argument(
"--input", "--input",
...@@ -99,7 +102,9 @@ def main(): ...@@ -99,7 +102,9 @@ def main():
assert os.path.isdir(args.output), args.output assert os.path.isdir(args.output), args.output
out_filename = os.path.join(args.output, os.path.basename(path)) out_filename = os.path.join(args.output, os.path.basename(path))
else: else:
assert len(args.input) == 1, "Please specify a directory with args.output" assert (
len(args.input) == 1
), "Please specify a directory with args.output"
out_filename = args.output out_filename = args.output
visualized_output.save(out_filename) visualized_output.save(out_filename)
else: else:
...@@ -145,6 +150,7 @@ def main(): ...@@ -145,6 +150,7 @@ def main():
output_file.release() output_file.release()
else: else:
cv2.destroyAllWindows() cv2.destroyAllWindows()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -2,4 +2,4 @@ ...@@ -2,4 +2,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from . import models, util, datasets from . import models, util, datasets
__all__ = ['models', 'util', 'datasets'] __all__ = ["models", "util", "datasets"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment