Commit 82295dbf authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable black for mobile-vision

Summary:
https://fb.workplace.com/groups/pythonfoundation/posts/2990917737888352

Remove `mobile-vision` from opt-out list; leaving `mobile-vision/SNPE` opted out because of 3rd-party code.

arc lint --take BLACK --apply-patches --paths-cmd 'hg files mobile-vision'

allow-large-files

Reviewed By: sstsai-adl

Differential Revision: D30721093

fbshipit-source-id: 9e5c16d988b315b93a28038443ecfb92efd18ef8
parent a56c7e15
......@@ -4,12 +4,15 @@
import copy
class FBNetV2ModelArch(object):
_MODEL_ARCH = {}
@staticmethod
def add(name, arch):
    """Register ``arch`` under ``name`` in the global arch registry.

    Raises AssertionError if ``name`` has already been registered, to
    prevent silently overwriting an existing architecture definition.
    """
    # Fix: the diff residue left both the old and the reformatted assert in
    # place; keep a single (black-formatted) assertion.
    assert (
        name not in FBNetV2ModelArch._MODEL_ARCH
    ), "Arch name '{}' is already existed".format(name)
    FBNetV2ModelArch._MODEL_ARCH[name] = arch
@staticmethod
......
......@@ -3,8 +3,9 @@
import copy
from mobile_cv.arch.fbnet_v2.modeldef_utils import _ex, e1, e2, e1p, e3, e4, e6
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
from mobile_cv.arch.fbnet_v2.modeldef_utils import _ex, e1, e2, e1p, e3, e4, e6
def _mutated_tuple(tp, pos, value):
......@@ -34,7 +35,6 @@ _BASIC_ARGS = {
# FBNetV1 builder.
# "always_pw": True,
# "bias": False,
# temporarily disable zero_last_bn_gamma
"zero_last_bn_gamma": False,
}
......@@ -59,10 +59,7 @@ IRF_CFG = {"less_se_channels": False}
FBNetV3_A_dsmask = [
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 1, {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 1, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 32, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 32, 1, 1, {"expansion": 2}, IRF_CFG),
......@@ -85,10 +82,7 @@ FBNetV3_A_dsmask = [
]
FBNetV3_A_dsmask_tiny = [
[
("conv_k3", 8, 2, 1),
("ir_k3", 8, 1, 1, {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 8, 2, 1), ("ir_k3", 8, 1, 1, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 16, 2, 1, {"expansion": 3}, IRF_CFG),
("ir_k5", 16, 1, 1, {"expansion": 2}, IRF_CFG),
......@@ -112,10 +106,7 @@ FBNetV3_A_dsmask_tiny = [
FBNetV3_A = [
# FBNetV3 arch without hs
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
......@@ -138,10 +129,7 @@ FBNetV3_A = [
]
FBNetV3_B = [
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
......@@ -303,10 +291,7 @@ FBNetV3_H = [
FBNetV3_A_no_se = [
# FBNetV3 without hs and SE (SE is not quantization friendly)
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 3}, IRF_CFG),
......@@ -329,10 +314,7 @@ FBNetV3_A_no_se = [
]
FBNetV3_B_no_se = [
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 3, {"expansion": 2}, IRF_CFG),
......@@ -357,10 +339,7 @@ FBNetV3_B_no_se = [
# FBNetV3_B model, a lighter version for real-time inference
FBNetV3_B_light_no_se = [
[
("conv_k3", 16, 2, 1),
("ir_k3", 16, 1, 2 , {"expansion": 1}, IRF_CFG)
],
[("conv_k3", 16, 2, 1), ("ir_k3", 16, 1, 2, {"expansion": 1}, IRF_CFG)],
[
("ir_k5", 24, 2, 1, {"expansion": 4}, IRF_CFG),
("ir_k5", 24, 1, 2, {"expansion": 2}, IRF_CFG),
......@@ -411,11 +390,21 @@ SMALL_UPSAMPLE_HEAD_STAGES = [
# NOTE: Compared with SMALL_UPSAMPLE_HEAD_STAGES, this does one more down-sample
# in the first "layer" and then up-sample twice
SMALL_DS_UPSAMPLE_HEAD_STAGES = [
[("ir_k3", 128, 2, 1, e4), ("ir_k3", 128, 1, 2, e6), ("ir_k3", 128, -2, 1, e6), ("ir_k3", 64, -2, 1, e3)], # noqa
[
("ir_k3", 128, 2, 1, e4),
("ir_k3", 128, 1, 2, e6),
("ir_k3", 128, -2, 1, e6),
("ir_k3", 64, -2, 1, e3),
], # noqa
]
TINY_DS_UPSAMPLE_HEAD_STAGES = [
[("ir_k3", 64, 2, 1, e4), ("ir_k3", 64, 1, 2, e4), ("ir_k3", 64, -2, 1, e4), ("ir_k3", 40, -2, 1, e3)], # noqa
[
("ir_k3", 64, 2, 1, e4),
("ir_k3", 64, 1, 2, e4),
("ir_k3", 64, -2, 1, e4),
("ir_k3", 40, -2, 1, e3),
], # noqa
]
FPN_UPSAMPLE_HEAD_STAGES = [
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import logging
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from d2go.config import CfgNode as CN
from d2go.data.dataset_mappers import (
D2GO_DATA_MAPPER_REGISTRY,
D2GoDatasetMapper,
)
from d2go.utils.helper import alias
from detectron2.layers import cat
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.utils.registry import Registry
from torch import nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
SUBCLASS_FETCHER_REGISTRY = Registry("SUBCLASS_FETCHER")
def add_subclass_configs(cfg):
_C = cfg
_C.MODEL.SUBCLASS = CN()
......@@ -32,42 +32,44 @@ def add_subclass_configs(cfg):
_C.MODEL.SUBCLASS.NUM_SUBCLASSES = 0 # must be set
_C.MODEL.SUBCLASS.NUM_LAYERS = 1
_C.MODEL.SUBCLASS.SUBCLASS_ID_FETCHER = "SubclassFetcher" # ABC, must be set
_C.MODEL.SUBCLASS.SUBCLASS_MAPPING = [] # subclass mapping from model output to annotation
_C.MODEL.SUBCLASS.SUBCLASS_MAPPING = (
[]
) # subclass mapping from model output to annotation
class SubclassFetcher(ABC):
""" Fetcher class to read subclass id annotations from dataset and prepare for train/eval.
Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator
to use with custom projects.
"""Fetcher class to read subclass id annotations from dataset and prepare for train/eval.
Subclass this and register with `@SUBCLASS_FETCHER_REGISTRY.register()` decorator
to use with custom projects.
"""
@property
@abstractmethod
def subclass_names(self) -> List[str]:
""" Overwrite this member with any new mappings' subclass names, which
may be useful for specific evaluation purposes.
len(self.subclass_names) should be equal to the expected number
of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1).
"""Overwrite this member with any new mappings' subclass names, which
may be useful for specific evaluation purposes.
len(self.subclass_names) should be equal to the expected number
of subclass head outputs (cfg.MODEL.SUBCLASS.NUM_SUBCLASSES + 1).
"""
pass
def remap(self, subclass_id: int) -> int:
    """Map a subclass id read from the dataset to a new label id.

    Default implementation is the identity mapping; subclasses may
    override to remap annotation ids to model-output ids.
    """
    # Fix: removed the duplicated pre-black docstring left by diff residue.
    return subclass_id
def fetch_subclass_ids(self, dataset_dict: Dict[str, Any]) -> List[int]:
    """Return the ``subclass_id`` of every annotation in ``dataset_dict``.

    Assumes each annotation carries an ``extras`` dict with a
    ``subclass_id`` key — TODO confirm against the dataset schema.
    """
    # Fix: removed the duplicated pre-black docstring left by diff residue.
    extras_list = [anno.get("extras") for anno in dataset_dict["annotations"]]
    subclass_ids = [extras["subclass_id"] for extras in extras_list]
    return subclass_ids
@D2GO_DATA_MAPPER_REGISTRY.register()
class SubclassDatasetMapper(D2GoDatasetMapper):
"""
Wrap any dataset mapper, encode gt_subclasses to the instances.
"""
def __init__(self, cfg, is_train, tfm_gens=None, subclass_fetcher=None):
super().__init__(cfg, is_train=is_train, tfm_gens=tfm_gens)
if subclass_fetcher is None:
......@@ -93,19 +95,23 @@ class SubclassDatasetMapper(D2GoDatasetMapper):
# Transform removes key 'annotations' from the dataset dict
mapped_dataset_dict = super()._original_call(dataset_dict)
if (self.is_train and self.subclass_on):
if self.is_train and self.subclass_on:
subclass_ids = self.subclass_fetcher.fetch_subclass_ids(dataset_dict)
subclasses = torch.tensor(subclass_ids, dtype=torch.int64)
mapped_dataset_dict["instances"].gt_subclasses = subclasses
return mapped_dataset_dict
def build_subclass_head(cfg, in_chann, out_chann):
    """Build the fully-connected subclass head.

    Stacks ``cfg.MODEL.SUBCLASS.NUM_LAYERS - 1`` square ``in_chann x
    in_chann`` linear layers followed by one ``in_chann x out_chann``
    output layer.

    Fix: the diff residue duplicated the ``layers = [...]`` assignment
    (old and reformatted versions back-to-back); keep a single one.
    """
    layers = [
        nn.Linear(in_chann, in_chann) for _ in range(cfg.MODEL.SUBCLASS.NUM_LAYERS - 1)
    ]
    layers.append(nn.Linear(in_chann, out_chann))
    return nn.Sequential(*layers)
@ROI_HEADS_REGISTRY.register()
class StandardROIHeadsWithSubClass(StandardROIHeads):
"""
......@@ -119,7 +125,9 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
return
self.num_subclasses = cfg.MODEL.SUBCLASS.NUM_SUBCLASSES
self.subclass_head = build_subclass_head(cfg, self.box_head.output_shape.channels, self.num_subclasses + 1)
self.subclass_head = build_subclass_head(
cfg, self.box_head.output_shape.channels, self.num_subclasses + 1
)
for layer in self.subclass_head:
nn.init.normal_(layer.weight, std=0.01)
......@@ -142,12 +150,16 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
for pp_per_im in proposals:
if not pp_per_im.has("gt_subclasses"):
background_subcls_idx = 0
pp_per_im.gt_subclasses = torch.cuda.LongTensor(len(pp_per_im)).fill_(background_subcls_idx)
pp_per_im.gt_subclasses = torch.cuda.LongTensor(
len(pp_per_im)
).fill_(background_subcls_idx)
del targets
features_list = [features[f] for f in self.in_features]
box_features = self.box_pooler(features_list, [x.proposal_boxes for x in proposals])
box_features = self.box_pooler(
features_list, [x.proposal_boxes for x in proposals]
)
box_features = self.box_head(box_features)
predictions = self.box_predictor(box_features)
# --- end copy ---------------------------------------------------------
......@@ -155,8 +167,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
# NOTE: don't delete box_features, keep it temporarily
# del box_features
box_features = box_features.view(
box_features.shape[0],
np.prod(box_features.shape[1:])
box_features.shape[0], np.prod(box_features.shape[1:])
)
pred_subclass_logits = self.subclass_head(box_features)
......@@ -195,8 +206,7 @@ class StandardROIHeadsWithSubClass(StandardROIHeads):
if torch.onnx.is_in_onnx_export():
assert len(pred_instances) == 1
pred_instances[0].pred_subclass_prob = alias(
pred_instances[0].pred_subclass_prob,
"subclass_prob_nms"
pred_instances[0].pred_subclass_prob, "subclass_prob_nms"
)
return pred_instances, {}
......@@ -2,4 +2,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .build import build_optimizer_mapper
__all__ = ['build_optimizer_mapper']
__all__ = ["build_optimizer_mapper"]
......@@ -12,9 +12,7 @@ from detectron2.utils.registry import Registry
D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")
def reduce_param_groups(
param_groups: List[Dict[str, Any]]
):
def reduce_param_groups(param_groups: List[Dict[str, Any]]):
# The number of parameter groups needs to be as small as possible in order
# to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
# of using a parameter_group per single parameter, we group all the params
......
......@@ -14,7 +14,7 @@ def create_runner(
class_full_name: str, *args, **kwargs
) -> Union[BaseRunner, Type[LightningModule]]:
"""Constructs a runner instance if class is a d2go runner. Returns class
type if class is a Lightning module.
type if class is a Lightning module.
"""
runner_module_name, runner_class_name = class_full_name.rsplit(".", 1)
runner_module = importlib.import_module(runner_module_name)
......
......@@ -29,13 +29,13 @@ PREPARED = "_prepared"
def rsetattr(obj: Any, attr: str, val: Any) -> None:
    """Same as setattr but supports deeply nested objects.

    ``attr`` may be dotted (``"a.b.c"``); intermediate objects are
    resolved via ``rgetattr`` before the final ``setattr``.
    """
    # Fix: removed the duplicated pre-black docstring left by diff residue.
    pre, _, post = attr.rpartition(".")
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
def rgetattr(obj: Any, attr: str, *args) -> Any:
""" Same as getattr but supports deeply nested objects. """
"""Same as getattr but supports deeply nested objects."""
def _getattr(obj, attr):
return getattr(obj, attr, *args)
......@@ -44,7 +44,7 @@ def rgetattr(obj: Any, attr: str, *args) -> Any:
def rhasattr(obj: Any, attr: str, *args) -> bool:
""" Same as hasattr but supports deeply nested objects. """
"""Same as hasattr but supports deeply nested objects."""
try:
_ = rgetattr(obj, attr, *args)
......@@ -66,7 +66,7 @@ def _deepcopy(pl_module: LightningModule) -> LightningModule:
def _quantized_forward(self, *args, **kwargs):
""" Forward method for a quantized module. """
"""Forward method for a quantized module."""
if not self.training and hasattr(self, "_quantized"):
return self._quantized(*args, **kwargs)
return self._prepared(*args, **kwargs)
......@@ -99,11 +99,7 @@ def checkpoint_has_prepared(checkpoint: Dict[str, Any]) -> bool:
def maybe_prepare_for_quantization(model: LightningModule, checkpoint: Dict[str, Any]):
    """Attach a QAT-prepared copy of ``model`` when loading a prepared checkpoint.

    If the checkpoint was saved from a QAT-prepared model and ``model`` does
    not yet carry the ``PREPARED`` attribute, deep-copy the model and run its
    ``prepare_for_quant`` to reconstruct the prepared module.
    """
    # Fix: the diff residue duplicated the setattr call (old multi-line and
    # reformatted versions back-to-back); keep a single one.
    if checkpoint_has_prepared(checkpoint) and not hasattr(model, PREPARED):
        # model has been prepared for QAT before saving into checkpoint
        setattr(model, PREPARED, _deepcopy(model).prepare_for_quant())
class QuantizationMixin(ABC):
......@@ -241,7 +237,7 @@ class ModelTransform:
interval: Optional[int] = None
def __post_init__(self) -> None:
""" Validate a few properties for early failure. """
"""Validate a few properties for early failure."""
if (self.step is None and self.interval is None) or (
self.step is not None and self.interval is not None
):
......@@ -469,7 +465,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
batch_idx: int,
dataloader_idx: int,
) -> None:
""" Applies model transforms at as specified during training. """
"""Applies model transforms at as specified during training."""
apply_only_once = []
current_step = trainer.global_step
for i, transform in enumerate(self.transforms):
......@@ -492,7 +488,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
]
def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
""" Quantize the weights since training has finalized. """
"""Quantize the weights since training has finalized."""
if hasattr(pl_module, "_quantized") or self.skip_conversion:
return
pl_module._quantized = self.convert(
......@@ -563,7 +559,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
qconfig_dicts: Optional[QConfigDicts] = None,
preserved_attrs: Optional[List[str]] = None,
) -> None:
""" Initialize the callback. """
"""Initialize the callback."""
self.qconfig_dicts = qconfig_dicts or {"": {"": get_default_qconfig()}}
self.preserved_attrs = set([] if preserved_attrs is None else preserved_attrs)
self.prepared: Optional[torch.nn.Module] = None
......@@ -593,7 +589,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
)
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Convert the calibrated model to its finalized quantized version."""
    # Fix: removed the duplicated pre-black docstring left by diff residue.
    self.quantized = self.convert(
        self.prepared, self.qconfig_dicts.keys(), attrs=self.preserved_attrs
    )
......@@ -607,7 +603,7 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
batch_idx: int,
dataloader_idx: int,
) -> None:
""" Also run the validation batch through the quantized model for calibration. """
"""Also run the validation batch through the quantized model for calibration."""
if self.should_calibrate:
with torch.no_grad():
self.prepared(batch)
......@@ -395,7 +395,7 @@ class Detectron2GoRunner(BaseRunner):
return results
def do_test(self, cfg, model, train_iter=None):
""" do_test does not load the weights of the model.
"""do_test does not load the weights of the model.
If you want to use it outside the regular training routine,
you will have to load the weights through a checkpointer.
"""
......
......@@ -15,10 +15,10 @@ from d2go.data.utils import (
update_cfg_if_using_adhoc_dataset,
)
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.modeling import build_model
from d2go.modeling.model_freezing_utils import (
set_requires_grad,
)
from d2go.modeling import build_model
from d2go.modeling.quantization import (
default_prepare_for_quant,
default_prepare_for_quant_convert,
......@@ -206,7 +206,9 @@ class DefaultTask(pl.LightningModule):
flattened = pl.loggers.LightningLoggerBase._flatten_dict(nested_res)
if self.trainer.global_rank:
assert len(flattened) == 0, "evaluation results should have been reduced on rank 0."
assert (
len(flattened) == 0
), "evaluation results should have been reduced on rank 0."
self.log_dict(flattened, rank_zero_only=True)
def test_epoch_end(self, _outputs) -> None:
......
......@@ -17,12 +17,12 @@ from d2go.config import (
)
from d2go.distributed import get_local_rank, get_num_processes_per_machine
from d2go.runner import GeneralizedRCNNRunner, create_runner
from d2go.utils.helper import run_once
from d2go.utils.launch_environment import get_launch_environment
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from detectron2.utils.serialize import PicklableWrapper
from d2go.utils.helper import run_once
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.py import FolderLock, MultiprocessingPdb, post_mortem_if_fail
......@@ -34,7 +34,7 @@ def basic_argument_parser(
requires_config_file=True,
requires_output_dir=True,
):
""" Basic cli tool parser for Detectron2Go binaries """
"""Basic cli tool parser for Detectron2Go binaries"""
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument(
"--runner",
......@@ -201,6 +201,7 @@ def _setup_after_launch(cfg: CN, output_dir: str, runner):
cfg.OUTPUT_DIR = output_dir
dump_cfg(cfg, os.path.join(output_dir, "config.yaml"))
def setup_after_launch(cfg: CN, output_dir: str, runner):
_setup_after_launch(cfg, output_dir, runner)
logger.info("Initializing runner ...")
......@@ -210,10 +211,12 @@ def setup_after_launch(cfg: CN, output_dir: str, runner):
auto_scale_world_size(cfg, new_world_size=comm.get_world_size())
def setup_after_lightning_launch(cfg: CN, output_dir: str):
    """Post-launch setup for Lightning-based training.

    Runs the shared ``_setup_after_launch`` bookkeeping and logs config
    info; no d2go runner is involved (``runner=None``).
    """
    _setup_after_launch(cfg, output_dir, runner=None)
    log_info(cfg, runner=None)
@run_once()
def setup_loggers(output_dir, color=None):
if not color:
......
......@@ -31,8 +31,8 @@ def get_rel_loss_checker(rel_thres=1.0):
class TrainImageWriter(object):
def __init__(self, cfg, tbx_writer, max_count=5):
""" max_count: max number of data written to tensorboard, additional call
will be ignored
"""max_count: max number of data written to tensorboard, additional call
will be ignored
"""
self.visualizer = VisualizerWrapper(cfg)
self.writer = tbx_writer
......@@ -58,8 +58,8 @@ class TrainImageWriter(object):
class FileWriter(object):
def __init__(self, output_dir, max_count=5):
""" max_count: max number of data written to tensorboard, additional call
will be ignored
"""max_count: max number of data written to tensorboard, additional call
will be ignored
"""
self.output_dir = output_dir
self.max_count = max_count
......
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import deque
import cv2
import detectron2.data.transforms as T
import torch
from d2go.model_zoo import model_zoo
from detectron2.data import MetadataCatalog
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer
import detectron2.data.transforms as T
from d2go.model_zoo import model_zoo
class DemoPredictor:
def __init__(self, model, min_size_test=224, max_size_test=320, input_format="RGB"):
    """Wrap ``model`` for eval-mode single-image inference.

    The resize augmentation scales the shorter image edge to
    ``min_size_test`` while capping the longer edge at ``max_size_test``.
    """
    self.model = model
    self.model.eval()
    # Fix: the diff residue duplicated this assignment (old multi-line and
    # reformatted versions back-to-back); keep a single one.
    self.aug = T.ResizeShortestEdge([min_size_test, min_size_test], max_size_test)
    self.input_format = input_format
......@@ -43,6 +41,7 @@ class DemoPredictor:
predictions = self.model([inputs])[0]
return predictions
class VisualizationDemo(object):
def __init__(self, cfg, config_file, instance_mode=ColorMode.IMAGE, parallel=False):
"""
......@@ -59,7 +58,7 @@ class VisualizationDemo(object):
self.instance_mode = instance_mode
self.parallel = parallel
model = model_zoo.get(config_file, trained=True)#runner.build_model(cfg)
model = model_zoo.get(config_file, trained=True) # runner.build_model(cfg)
self.predictor = DemoPredictor(model)
def run_on_image(self, image):
......@@ -123,7 +122,9 @@ class VisualizationDemo(object):
)
elif "instances" in predictions:
predictions = predictions["instances"].to(self.cpu_device)
vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
vis_frame = video_visualizer.draw_instance_predictions(
frame, predictions
)
elif "sem_seg" in predictions:
vis_frame = video_visualizer.draw_sem_seg(
frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
......
......@@ -31,13 +31,13 @@ class EMAState(object):
decay: float = 0.999,
device: Optional[str] = None,
) -> "EMAState":
""" Constructs model state from the model and move to device if given."""
"""Constructs model state from the model and move to device if given."""
ret = cls(decay, device)
ret.load_from(model)
return ret
def load_from(self, model: nn.Module) -> None:
""" Load state from the model. """
"""Load state from the model."""
self.state.clear()
for name, val in self._get_model_state_iterator(model):
val = val.detach().clone()
......@@ -47,7 +47,7 @@ class EMAState(object):
return len(self.state) > 0
def apply_to(self, model: nn.Module) -> None:
""" Apply EMA state to the model. """
"""Apply EMA state to the model."""
with torch.no_grad():
for name, val in self._get_model_state_iterator(model):
assert (
......@@ -64,7 +64,7 @@ class EMAState(object):
self.state[name] = val.to(self.device) if self.device else val
def to(self, device: torch.device) -> None:
    """Move all EMA state tensors to ``device`` in place."""
    # Fix: removed the duplicated pre-black docstring left by diff residue.
    for name, val in self.state.items():
        self.state[name] = val.to(device)
......
......@@ -2,16 +2,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import torch
import os
import logging
import os
import detectron2.utils.comm as comm
from detectron2.utils.file_io import PathManager
from detectron2.utils.analysis import FlopCountAnalysis
from fvcore.nn import flop_count_table, flop_count_str
import mobile_cv.lut.lib.pt.flops_utils as flops_utils
import torch
from d2go.utils.helper import run_once
from detectron2.utils.analysis import FlopCountAnalysis
from detectron2.utils.file_io import PathManager
from fvcore.nn import flop_count_table, flop_count_str
logger = logging.getLogger(__name__)
......@@ -58,7 +58,6 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
except Exception:
logger.exception("Failed to estimate flops using mobile_cv's FlopsEstimation")
# 2. using d2/fvcore's flop counter
try:
flops = FlopCountAnalysis(model, inputs)
......@@ -81,14 +80,16 @@ def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
flops_table = flop_count_table(flops, max_depth=3)
logger.info("Flops table:\n" + flops_table)
except Exception:
logger.exception("Failed to estimate flops using detectron2's FlopCountAnalysis")
logger.exception(
"Failed to estimate flops using detectron2's FlopCountAnalysis"
)
return flops
def add_flop_printing_hook(
model,
output_dir: str,
):
model,
output_dir: str,
):
"""
Add a pytorch module forward hook that will print/save flops of the whole model
at the first time the model is called.
......@@ -96,6 +97,7 @@ def add_flop_printing_hook(
Args:
output_dir: directory to save more detailed flop info
"""
def hook(module, input):
handle.remove()
dump_flops_info(module, input, output_dir)
......
......@@ -6,11 +6,11 @@ from d2go.data.build import (
add_random_subset_training_sampler_default_configs,
)
from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.backbone.fbnet_cfg import (
add_bifpn_default_configs,
add_fbnet_v2_default_configs,
)
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.model_freezing_utils import add_model_freezing_configs
from d2go.modeling.quantization import add_quantization_default_configs
from d2go.modeling.subclass import add_subclass_configs
......
......@@ -7,24 +7,21 @@ import inspect
import logging
import math
import os
import re
import tempfile
import zipfile
import pickle
import re
import signal
import sys
import tempfile
import threading
import time
import traceback
import typing
import warnings
import pkg_resources
import zipfile
from contextlib import contextmanager
from functools import partial
from random import random
import six
from functools import wraps
from random import random
from typing import (
Any,
Callable,
......@@ -39,12 +36,20 @@ from typing import (
Union,
)
import torch
import detectron2.utils.comm as comm
import pkg_resources
import six
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
hooks,
launch,
)
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
......@@ -66,9 +71,11 @@ NT = TypeVar("T", bound=NamedTuple)
from detectron2.utils.events import TensorboardXWriter
class MultipleFunctionCallError(Exception):
    """Raised when a function guarded by ``run_once(raise_on_multiple=True)``
    is invoked more than once."""

    pass
def run_once(
raise_on_multiple: bool = False,
# pyre-fixme[34]: `Variable[T]` isn't present in the function's parameters.
......@@ -102,8 +109,8 @@ def run_once(
class retryable(object):
"""Fake retryable function
"""
"""Fake retryable function"""
def __init__(self, num_tries=1, sleep_time=0.1):
pass
......@@ -134,6 +141,7 @@ def alias(x, name, is_backward=False):
assert isinstance(x, torch.Tensor)
return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
class D2Trainer(DefaultTrainer):
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
......@@ -183,6 +191,7 @@ class D2Trainer(DefaultTrainer):
return evaluator_list[0]
return DatasetEvaluators(evaluator_list)
def reroute_config_path(path: str) -> str:
"""
Supporting rerouting the config files for convenience:
......
......@@ -6,5 +6,6 @@ import logging
import os
from functools import lru_cache
def get_tensorboard_log_dir(output_dir):
    """Return the tensorboard log directory for a run.

    Currently the identity: tensorboard logs live directly in
    ``output_dir``.
    """
    return output_dir
......@@ -67,7 +67,7 @@ def enable_ddp_env(func):
def tempdir(func):
""" A decorator for creating a tempory directory that is cleaned up after function execution. """
"""A decorator for creating a tempory directory that is cleaned up after function execution."""
@wraps(func)
def wrapper(self, *args, **kwargs):
......
......@@ -5,14 +5,13 @@ import glob
import multiprocessing as mp
import os
import time
import cv2
import tqdm
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
from d2go.model_zoo import model_zoo
from d2go.utils.demo_predictor import VisualizationDemo
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger
# constants
WINDOW_NAME = "COCO detections"
......@@ -22,7 +21,9 @@ def setup_cfg(cfg, args):
# Set score_threshold for builtin models
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
args.confidence_threshold
)
cfg.freeze()
return cfg
......@@ -31,11 +32,13 @@ def get_parser():
parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
parser.add_argument(
"--config-file",
default='keypoint_rcnn_fbnetv3a_dsmask_C4.yaml',
default="keypoint_rcnn_fbnetv3a_dsmask_C4.yaml",
metavar="FILE",
help="path to config file",
)
parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
parser.add_argument(
"--webcam", action="store_true", help="Take inputs from webcam."
)
parser.add_argument("--video-input", help="Path to video file.")
parser.add_argument(
"--input",
......@@ -99,7 +102,9 @@ def main():
assert os.path.isdir(args.output), args.output
out_filename = os.path.join(args.output, os.path.basename(path))
else:
assert len(args.input) == 1, "Please specify a directory with args.output"
assert (
len(args.input) == 1
), "Please specify a directory with args.output"
out_filename = args.output
visualized_output.save(out_filename)
else:
......@@ -145,6 +150,7 @@ def main():
output_file.release()
else:
cv2.destroyAllWindows()
if __name__ == "__main__":
main()
......@@ -2,4 +2,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from . import models, util, datasets
__all__ = ['models', 'util', 'datasets']
__all__ = ["models", "util", "datasets"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.