Commit 82295dbf authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable black for mobile-vision

Summary:
https://fb.workplace.com/groups/pythonfoundation/posts/2990917737888352

Remove `mobile-vision` from opt-out list; leaving `mobile-vision/SNPE` opted out because of 3rd-party code.

arc lint --take BLACK --apply-patches --paths-cmd 'hg files mobile-vision'

allow-large-files

Reviewed By: sstsai-adl

Differential Revision: D30721093

fbshipit-source-id: 9e5c16d988b315b93a28038443ecfb92efd18ef8
parent a56c7e15
......@@ -49,7 +49,7 @@ def add_random_subset_training_sampler_default_configs(cfg: CfgNode):
Add default cfg.DATALOADER.RANDOM_SUBSET_RATIO for RandomSubsetTrainingSampler
The CfgNode under cfg.DATALOADER.RANDOM_SUBSET_RATIO should be a float > 0 and <= 1
"""
cfg.DATALOADER.RANDOM_SUBSET_RATIO = 1.
cfg.DATALOADER.RANDOM_SUBSET_RATIO = 1.0
def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
......
......@@ -7,9 +7,9 @@ import logging
import numpy as np
import torch
from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
from detectron2.data import detection_utils as utils, transforms as T
from detectron2.structures import BoxMode, Instances, RotatedBoxes
from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
from .build import D2GO_DATA_MAPPER_REGISTRY
......
......@@ -190,8 +190,14 @@ def convert_to_dict_list(image_root, id_map, imgs, anns, dataset_name=None):
bbox_object = obj.get("bbox", None)
if bbox_object is not None and "bbox_mode" in obj:
bbox_object = BoxMode.convert(bbox_object, obj["bbox_mode"], BoxMode.XYWH_ABS)
if record.get("width") and record.get("height") and not valid_bbox(bbox_object, record["width"], record["height"]):
bbox_object = BoxMode.convert(
bbox_object, obj["bbox_mode"], BoxMode.XYWH_ABS
)
if (
record.get("width")
and record.get("height")
and not valid_bbox(bbox_object, record["width"], record["height"])
):
num_instances_without_valid_bounding_box += 1
continue
......
......@@ -13,19 +13,20 @@ raw data with fields such as:
...
"""
import os
import json
import logging
import os
import tempfile
from pathlib import Path
from detectron2.utils.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.file_io import PathManager
logger = logging.getLogger(__name__)
IMG_EXTENSIONS = ['.jpg', '.JPG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
IMG_EXTENSIONS = [".jpg", ".JPG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"]
def is_image_file(filename):
......@@ -65,7 +66,7 @@ def load_pix2pix_image_folder(image_root, input_folder="input", gt_folder="gt"):
f = {
"file_name": fname[:-4],
"input_path": input_path,
"gt_path": gt_path
"gt_path": gt_path,
}
data.append(f)
if image_root.rfind("test") != -1 and len(data) == 5000:
......@@ -76,8 +77,13 @@ def load_pix2pix_image_folder(image_root, input_folder="input", gt_folder="gt"):
def load_pix2pix_json(
json_path, input_folder, gt_folder, mask_folder,
real_json_path=None, real_folder=None, max_num=1e10,
json_path,
input_folder,
gt_folder,
mask_folder,
real_json_path=None,
real_folder=None,
max_num=1e10,
):
"""
Args:
......@@ -90,11 +96,11 @@ def load_pix2pix_json(
"""
real_filenames = {}
if real_json_path is not None:
with PathManager.open(real_json_path, 'r') as f:
with PathManager.open(real_json_path, "r") as f:
real_filenames = json.load(f)
data = []
with PathManager.open(json_path, 'r') as f:
with PathManager.open(json_path, "r") as f:
filenames = json.load(f)
in_len = len(filenames)
......@@ -110,9 +116,9 @@ def load_pix2pix_json(
fname = in_keys[cnt % in_len]
input_label = filenames[fname]
if isinstance(input_label, tuple) or isinstance(input_label, list):
assert len(input_label) == 2, (
"Save (real_name, label) as the value of the json dict for resampling"
)
assert (
len(input_label) == 2
), "Save (real_name, label) as the value of the json dict for resampling"
fname, input_label = input_label
f = {
......@@ -121,7 +127,7 @@ def load_pix2pix_json(
"gt_folder": gt_folder,
"mask_folder": mask_folder,
"input_label": input_label,
"real_folder": real_folder
"real_folder": real_folder,
}
if real_len > 0:
real_fname = real_keys[cnt % real_len]
......@@ -151,10 +157,16 @@ def register_folder_dataset(
max_num=1e10,
):
DatasetCatalog.register(
name, lambda: load_pix2pix_json(
json_path, input_folder, gt_folder, mask_folder,
real_json_path, real_folder, max_num
)
name,
lambda: load_pix2pix_json(
json_path,
input_folder,
gt_folder,
mask_folder,
real_json_path,
real_folder,
max_num,
),
)
metadata = {
"input_src_path": input_src_path,
......@@ -190,9 +202,7 @@ def register_lmdb_dataset(
src_data_folder,
max_num,
):
DatasetCatalog.register(
name, lambda: load_lmdb_keys(max_num)
)
DatasetCatalog.register(name, lambda: load_lmdb_keys(max_num))
metadata = {
"data_folder": data_folder,
"src_data_folder": src_data_folder,
......@@ -205,22 +215,23 @@ def inject_gan_datasets(cfg):
if cfg.D2GO_DATA.DATASETS.GAN_INJECTION.ENABLE:
name = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.NAME
cfg.merge_from_list(
["DATASETS.TRAIN",
[
"DATASETS.TRAIN",
list(cfg.DATASETS.TRAIN) + [name + "_train"],
"DATASETS.TEST",
list(cfg.DATASETS.TEST) + [name + "_test"]
list(cfg.DATASETS.TEST) + [name + "_test"],
]
)
json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.JSON_PATH
assert PathManager.isfile(json_path), (
"{} is not valid!".format(json_path))
assert PathManager.isfile(json_path), "{} is not valid!".format(json_path)
image_dir = Path(tempfile.mkdtemp())
input_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.INPUT_SRC_DIR
assert PathManager.isfile(input_src_path), (
"{} is not valid!".format(input_src_path))
assert PathManager.isfile(input_src_path), "{} is not valid!".format(
input_src_path
)
input_folder = os.path.join(image_dir, name, "input")
gt_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.GT_SRC_DIR
......@@ -228,25 +239,26 @@ def inject_gan_datasets(cfg):
gt_folder = os.path.join(image_dir, name, "gt")
else:
gt_src_path = None
gt_folder=None
gt_folder = None
mask_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.MASK_SRC_DIR
if PathManager.isfile(mask_src_path):
mask_folder = os.path.join(image_dir, name, "mask")
else:
mask_src_path = None
mask_folder=None
mask_folder = None
real_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_SRC_DIR
if PathManager.isfile(real_src_path):
real_folder = os.path.join(image_dir, name, "real")
real_json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_JSON_PATH
assert PathManager.isfile(real_json_path), (
"{} is not valid!".format(real_json_path))
assert PathManager.isfile(real_json_path), "{} is not valid!".format(
real_json_path
)
else:
real_src_path = None
real_folder=None
real_json_path=None
real_folder = None
real_json_path = None
register_folder_dataset(
name + "_train",
......
......@@ -12,7 +12,7 @@ from .build import TRANSFORM_OP_REGISTRY, _json_load
class LocalizedBoxMotionBlurTransform(Transform):
""" Transform to blur provided bounding boxes from an image. """
"""Transform to blur provided bounding boxes from an image."""
def __init__(
self,
......@@ -36,15 +36,15 @@ class LocalizedBoxMotionBlurTransform(Transform):
return new_img
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
""" Apply no transform on the full-image segmentation. """
"""Apply no transform on the full-image segmentation."""
return segmentation
def apply_coords(self, coords: np.ndarray):
""" Apply no transform on the coordinates. """
"""Apply no transform on the coordinates."""
return coords
def inverse(self) -> Transform:
""" The inverse is a No-op, only for geometric transforms. """
"""The inverse is a No-op, only for geometric transforms."""
return NoOpTransform()
......
......@@ -81,7 +81,7 @@ class PadBorderDivisible(aug.Augmentation):
self.pad_mode = pad_mode
def get_transform(self, image: np.ndarray) -> Transform:
""" image: HxWxC """
"""image: HxWxC"""
assert len(image.shape) == 3 and image.shape[2] in [1, 3]
H, W = image.shape[:2]
new_h = int(math.ceil(H / self.size_divisibility) * self.size_divisibility)
......
......@@ -59,7 +59,7 @@ class AugInput:
class Tensor2Array(Transform):
""" Convert image tensor (CHW) to np array (HWC) """
"""Convert image tensor (CHW) to np array (HWC)"""
def __init__(self):
super().__init__()
......@@ -82,7 +82,7 @@ class Tensor2Array(Transform):
class Array2Tensor(Transform):
""" Convert image np array (HWC) to torch tensor (CHW) """
"""Convert image np array (HWC) to torch tensor (CHW)"""
def __init__(self):
super().__init__()
......
......@@ -233,7 +233,7 @@ class COCOWithClassesToUse(AdhocCOCODataset):
# check if name is already a derived class and try to reverse it
res = re.match("(?P<src>.+)@(?P<num>[0-9]+)classes", src_ds_name)
if res is not None:
src_ds_name = res['src']
src_ds_name = res["src"]
super().__init__(
src_ds_name=src_ds_name,
......
......@@ -155,7 +155,6 @@ def _distributed_worker(
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
ret = main_func(*args)
if global_rank == 0:
logger.info(
......
......@@ -5,8 +5,8 @@ import itertools
import logging
from collections import OrderedDict
import numpy as np
import detectron2.utils.comm as comm
import numpy as np
from detectron2.evaluation import DatasetEvaluator
logger = logging.getLogger(__name__)
......
......@@ -40,22 +40,23 @@ def _register_d2_datasets():
@_record_times(REGISTER_TIME)
def _register():
from d2go.modeling.backbone import ( # NOQA
fbnet_v2,
)
from d2go.data import dataset_mappers # NOQA
from d2go.data.datasets import (
register_json_datasets,
register_builtin_datasets,
)
from d2go.modeling.backbone import ( # NOQA
fbnet_v2,
)
#register_json_datasets()
#register_builtin_datasets()
# register_json_datasets()
# register_builtin_datasets()
def initialize_all():
# exclude torch from timing
from torchvision.ops import nms # noqa
_setup_env()
_register_d2_datasets()
_register()
......
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from typing import Optional
import pkg_resources
import torch
from detectron2.checkpoint import DetectionCheckpointer
from d2go.runner import create_runner
from detectron2.checkpoint import DetectionCheckpointer
class _ModelZooUrls(object):
"""
Mapping from names to officially released D2Go pre-trained models.
"""
S3_PREFIX = "https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/"
CONFIG_PATH_TO_URL_SUFFIX = {
"faster_rcnn_fbnetv3a_C4.yaml": "268421013/model_final.pth",
......@@ -56,7 +58,9 @@ def get_config_file(config_path):
return cfg_file
def get_config(config_path, trained: bool = False, runner="d2go.runner.GeneralizedRCNNRunner"):
def get_config(
config_path, trained: bool = False, runner="d2go.runner.GeneralizedRCNNRunner"
):
"""
Returns a config object for a model in model zoo.
Args:
......@@ -77,7 +81,13 @@ def get_config(config_path, trained: bool = False, runner="d2go.runner.Generaliz
cfg.MODEL.WEIGHTS = get_checkpoint_url(config_path)
return cfg
def get(config_path, trained: bool = False, device: Optional[str] = None, runner="d2go.runner.GeneralizedRCNNRunner"):
def get(
config_path,
trained: bool = False,
device: Optional[str] = None,
runner="d2go.runner.GeneralizedRCNNRunner",
):
"""
Get a model specified by relative path under Detectron2's official ``configs/`` directory.
Args:
......
......@@ -6,8 +6,7 @@ from d2go.config import CfgNode as CN
def add_fbnet_default_configs(_C):
""" FBNet options and default values
"""
"""FBNet options and default values"""
_C.MODEL.FBNET = CN()
_C.MODEL.FBNET.ARCH = "default"
# custom arch
......
......@@ -9,6 +9,7 @@ from typing import List
import torch
import torch.nn as nn
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
from detectron2.layers import ShapeSpec
from detectron2.modeling import (
BACKBONE_REGISTRY,
......@@ -20,7 +21,6 @@ from detectron2.modeling.backbone.fpn import FPN, LastLevelMaxPool, LastLevelP6P
from detectron2.modeling.roi_heads import box_head, keypoint_head, mask_head
from detectron2.utils.logger import log_first_n
from mobile_cv.arch.fbnet_v2 import fbnet_builder as mbuilder
from d2go.modeling.modeldef.fbnet_modeldef_registry import FBNetV2ModelArch
from mobile_cv.arch.utils.helper import format_dict_expanding_list_values
from .modules import (
......@@ -49,7 +49,9 @@ def _get_builder_norm_args(cfg):
def _merge_fbnetv2_arch_def(cfg):
arch_def = {}
assert all(isinstance(x, dict) for x in cfg.MODEL.FBNET_V2.ARCH_DEF), cfg.MODEL.FBNET_V2.ARCH_DEF
assert all(
isinstance(x, dict) for x in cfg.MODEL.FBNET_V2.ARCH_DEF
), cfg.MODEL.FBNET_V2.ARCH_DEF
for dic in cfg.MODEL.FBNET_V2.ARCH_DEF:
arch_def.update(dic)
return arch_def
......@@ -58,16 +60,17 @@ def _merge_fbnetv2_arch_def(cfg):
def _parse_arch_def(cfg):
arch = cfg.MODEL.FBNET_V2.ARCH
arch_def = cfg.MODEL.FBNET_V2.ARCH_DEF
assert (arch != "" and not arch_def) ^ (not arch and arch_def != []), (
"Only allow one unset node between MODEL.FBNET_V2.ARCH ({}) and MODEL.FBNET_V2.ARCH_DEF ({})"
.format(arch, arch_def)
assert (arch != "" and not arch_def) ^ (
not arch and arch_def != []
), "Only allow one unset node between MODEL.FBNET_V2.ARCH ({}) and MODEL.FBNET_V2.ARCH_DEF ({})".format(
arch, arch_def
)
arch_def = FBNetV2ModelArch.get(arch) if arch else _merge_fbnetv2_arch_def(cfg)
# NOTE: arch_def is a dictionary describing the CNN architecture for creating
# the detection model. It can describe a wide range of models including the
# original FBNet. Each key-value pair expresses either a sub part of the model
# like trunk or head, or stores other meta information.
message = "Using un-unified arch_def for ARCH \"{}\" (without scaling):\n{}".format(
message = 'Using un-unified arch_def for ARCH "{}" (without scaling):\n{}'.format(
arch, format_dict_expanding_list_values(arch_def)
)
log_first_n(logging.INFO, message, n=1, key="message")
......@@ -129,13 +132,15 @@ def _get_stride_per_stage(blocks):
def fbnet_identifier_checker(func):
""" Can be used to decorate _load_from_state_dict """
"""Can be used to decorate _load_from_state_dict"""
def wrapper(self, state_dict, prefix, *args, **kwargs):
possible_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
if not all(FBNET_BUILDER_IDENTIFIER in k for k in possible_keys):
logger.warning(
"Couldn't match FBNetV2 pattern given prefix {}, possible keys: \n{}"
.format(prefix, "\n".join(possible_keys))
"Couldn't match FBNetV2 pattern given prefix {}, possible keys: \n{}".format(
prefix, "\n".join(possible_keys)
)
)
if any("xif" in k for k in possible_keys):
raise RuntimeError(
......@@ -146,6 +151,7 @@ def fbnet_identifier_checker(func):
" still found, see D19477651 as example."
)
return func(self, state_dict, prefix, *args, **kwargs)
return wrapper
......@@ -183,8 +189,9 @@ def build_fbnet(cfg, name, in_channels):
arch_def = mbuilder.unify_arch_def(raw_arch_def, [name])
arch_def = {name: arch_def[name]}
logger.info(
"Build FBNet using unified arch_def:\n{}"
.format(format_dict_expanding_list_values(arch_def))
"Build FBNet using unified arch_def:\n{}".format(
format_dict_expanding_list_values(arch_def)
)
)
arch_def_blocks = arch_def[name]
......@@ -192,15 +199,19 @@ def build_fbnet(cfg, name, in_channels):
trunk_stride_per_stage = _get_stride_per_stage(arch_def_blocks)
shape_spec_per_stage = []
for i, stride_i in enumerate(trunk_stride_per_stage):
stages.append(builder.build_blocks(
stages.append(
builder.build_blocks(
arch_def_blocks,
stage_indices=[i],
prefix_name=FBNET_BUILDER_IDENTIFIER + "_",
))
shape_spec_per_stage.append(ShapeSpec(
)
)
shape_spec_per_stage.append(
ShapeSpec(
channels=builder.last_depth,
stride=stride_i,
))
)
)
return FBNetModule(*stages), shape_spec_per_stage
......@@ -226,9 +237,7 @@ class FBNetV2Backbone(Backbone):
def __init__(self, cfg):
super(FBNetV2Backbone, self).__init__()
stages, shape_specs = build_fbnet(
cfg,
name="trunk",
in_channels=cfg.MODEL.FBNET_V2.STEM_IN_CHANNELS
cfg, name="trunk", in_channels=cfg.MODEL.FBNET_V2.STEM_IN_CHANNELS
)
self._trunk_stage_names = []
......@@ -338,9 +347,7 @@ class FBNetV2RpnHead(nn.Module):
num_cell_anchors = num_cell_anchors[0]
self.rpn_feature, shape_specs = build_fbnet(
cfg,
name="rpn",
in_channels=in_channels
cfg, name="rpn", in_channels=in_channels
)
self.rpn_regressor = RPNHeadConvRegressor(
in_channels=shape_specs[-1].channels,
......@@ -359,9 +366,7 @@ class FBNetV2RoIBoxHead(nn.Module):
super(FBNetV2RoIBoxHead, self).__init__()
self.roi_box_conv, shape_specs = build_fbnet(
cfg,
name="bbox",
in_channels=input_shape.channels
cfg, name="bbox", in_channels=input_shape.channels
)
self._out_channels = shape_specs[-1].channels
......@@ -388,9 +393,7 @@ class FBNetV2RoIKeypointHead(keypoint_head.BaseKeypointRCNNHead):
)
self.feature_extractor, shape_specs = build_fbnet(
cfg,
name="kpts",
in_channels=input_shape.channels
cfg, name="kpts", in_channels=input_shape.channels
)
self.predictor = KeypointRCNNPredictor(
......@@ -462,7 +465,9 @@ class FBNetV2RoIKeypointHeadKPRCNNConvUpsamplePredictorNoUpscale(
keypoint_head.BaseKeypointRCNNHead,
):
def __init__(self, cfg, input_shape: ShapeSpec):
super(FBNetV2RoIKeypointHeadKPRCNNConvUpsamplePredictorNoUpscale, self).__init__(
super(
FBNetV2RoIKeypointHeadKPRCNNConvUpsamplePredictorNoUpscale, self
).__init__(
cfg=cfg,
input_shape=input_shape,
)
......
......@@ -3,11 +3,12 @@
import logging
import numpy as np
import torch
from typing import List
import detectron2.utils.comm as comm
import numpy as np
import torch
from d2go.config import temp_defrost, CfgNode as CN
from detectron2.engine import hooks
from detectron2.layers import ShapeSpec
from detectron2.modeling import GeneralizedRCNN
......@@ -18,7 +19,6 @@ from detectron2.modeling.anchor_generator import (
)
from detectron2.modeling.proposal_generator.rpn import RPN
from detectron2.structures.boxes import Boxes
from d2go.config import temp_defrost, CfgNode as CN
logger = logging.getLogger(__name__)
......@@ -68,7 +68,7 @@ def compute_kmeans_anchors_hook(runner, cfg):
@ANCHOR_GENERATOR_REGISTRY.register()
class KMeansAnchorGenerator(DefaultAnchorGenerator):
""" Generate anchors using pre-computed KMEANS_ANCHORS.COMPUTED_ANCHORS """
"""Generate anchors using pre-computed KMEANS_ANCHORS.COMPUTED_ANCHORS"""
def __init__(self, cfg, input_shape: List[ShapeSpec]):
torch.nn.Module.__init__(self)
......@@ -106,8 +106,9 @@ class KMeansAnchorGenerator(DefaultAnchorGenerator):
def collect_boxes_size_stats(data_loader, max_num_imgs, _legacy_plus_one=False):
logger.info(
"Collecting size of boxes, loading up to {} images from data loader ..."
.format(max_num_imgs)
"Collecting size of boxes, loading up to {} images from data loader ...".format(
max_num_imgs
)
)
# data_loader might be infinite length, thus can't loop all images, using
# max_num_imgs == 0 stands for 0 images instead of whole dataset
......@@ -140,8 +141,9 @@ def collect_boxes_size_stats(data_loader, max_num_imgs, _legacy_plus_one=False):
percentage = 100.0 * i / estimated_iters
logger.info(
"Processed batch {} ({:.2f}%) from data_loader, got {} boxes,"
" remaining number of images: {}/{}"
.format(i, percentage, len(box_sizes), remaining_num_imgs, max_num_imgs)
" remaining number of images: {}/{}".format(
i, percentage, len(box_sizes), remaining_num_imgs, max_num_imgs
)
)
if remaining_num_imgs <= 0:
assert remaining_num_imgs == 0
......@@ -149,21 +151,17 @@ def collect_boxes_size_stats(data_loader, max_num_imgs, _legacy_plus_one=False):
box_sizes = np.array(box_sizes)
logger.info(
"Collected {} boxes from {} images"
.format(len(box_sizes), max_num_imgs)
"Collected {} boxes from {} images".format(len(box_sizes), max_num_imgs)
)
return box_sizes
def compute_kmeans_anchors(
cfg,
data_loader,
sort_by_area=True,
_stride=0,
_legacy_plus_one=False
cfg, data_loader, sort_by_area=True, _stride=0, _legacy_plus_one=False
):
assert cfg.MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG > 0, \
"Please provide positive MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG"
assert (
cfg.MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG > 0
), "Please provide positive MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG"
num_training_img = cfg.MODEL.KMEANS_ANCHORS.NUM_TRAINING_IMG
div_i, mod_i = divmod(num_training_img, comm.get_world_size())
......@@ -179,9 +177,11 @@ def compute_kmeans_anchors(
box_sizes = np.concatenate(all_box_sizes)
logger.info("Collected {} boxes from all gpus".format(len(box_sizes)))
assert cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS > 0, \
"Please provide positive MODEL.KMEANS_ANCHORS.NUM_CLUSTERS"
assert (
cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS > 0
), "Please provide positive MODEL.KMEANS_ANCHORS.NUM_CLUSTERS"
from sklearn.cluster import KMeans # delayed import
default_anchors = (
KMeans(
n_clusters=cfg.MODEL.KMEANS_ANCHORS.NUM_CLUSTERS,
......@@ -214,12 +214,15 @@ def compute_kmeans_anchors(
anchors = anchors[indices]
sqrt_areas = sqrt_areas[indices].tolist()
display_str = "\n".join([
display_str = "\n".join(
[
s + "\t sqrt area: {:.2f}".format(a)
for s, a in zip(str(anchors).split("\n"), sqrt_areas)
])
]
)
logger.info(
"Compuated kmeans anchors (sorted by area: {}):\n{}"
.format(sort_by_area, display_str)
"Compuated kmeans anchors (sorted by area: {}):\n{}".format(
sort_by_area, display_str
)
)
return anchors
......@@ -2,8 +2,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from detectron2.modeling import build_model as d2_build_model
from d2go.utils.misc import _log_api_usage
from detectron2.modeling import build_model as d2_build_model
def build_model(cfg):
......
......@@ -168,21 +168,24 @@ def _fx_quant_prepare(self, cfg):
self.backbone = prep_fn(
self.backbone,
qconfig,
prepare_custom_config_dict={"preserved_attributes": ["size_divisibility"],
prepare_custom_config_dict={
"preserved_attributes": ["size_divisibility"],
# keep the output of backbone quantized, to avoid
# redundant dequant
# TODO: output of backbone is a dict and currently this will keep all output
# quantized, when we fix the implementation of "output_quantized_idxs"
# we'll need to change this
"output_quantized_idxs": [0]},
"output_quantized_idxs": [0],
},
)
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature, qconfig,
self.proposal_generator.rpn_head.rpn_feature,
qconfig,
prepare_custom_config_dict={
# rpn_feature expecting quantized input, this is used to avoid redundant
# quant
"input_quantized_idxs": [0]
}
},
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
......@@ -191,27 +194,26 @@ def _fx_quant_prepare(self, cfg):
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv, qconfig,
self.roi_heads.box_head.roi_box_conv,
qconfig,
prepare_custom_config_dict={
"output_quantized_idxs": [0],
},
)
self.roi_heads.box_head.avgpool = prep_fn(
self.roi_heads.box_head.avgpool, qconfig,
prepare_custom_config_dict={
"input_quantized_idxs": [0]
})
self.roi_heads.box_head.avgpool,
qconfig,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score, qconfig,
prepare_custom_config_dict={
"input_quantized_idxs": [0]
}
self.roi_heads.box_predictor.cls_score,
qconfig,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred, qconfig,
prepare_custom_config_dict={
"input_quantized_idxs": [0]
}
self.roi_heads.box_predictor.bbox_pred,
qconfig,
prepare_custom_config_dict={"input_quantized_idxs": [0]},
)
......
......@@ -40,6 +40,7 @@ class AddCoordChannels(nn.Module):
@param with_r include radial distance from centroid as additional channel (default: False)
"""
def __init__(self, with_r: bool = False) -> None:
super().__init__()
self.with_r = with_r
......@@ -71,10 +72,14 @@ class AddCoordChannels(nn.Module):
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
out = torch.cat([input_tensor, xx_channel.to(device), yy_channel.to(device)], dim=1)
out = torch.cat(
[input_tensor, xx_channel.to(device), yy_channel.to(device)], dim=1
)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
rr = torch.sqrt(
torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)
)
out = torch.cat([out, rr], dim=1)
return out
......@@ -25,13 +25,13 @@ class EMAState(object):
return ret
def save_from(self, model: torch.nn.Module, device: str = ""):
""" Save model state from `model` to this object """
"""Save model state from `model` to this object"""
for name, val in self.get_model_state_iterator(model):
val = val.detach().clone()
self.state[name] = val.to(device) if device else val
def apply_to(self, model: torch.nn.Module):
""" Apply state to `model` from this object """
"""Apply state to `model` from this object"""
with torch.no_grad():
for name, val in self.get_model_state_iterator(model):
assert (
......@@ -91,7 +91,7 @@ class EMAState(object):
class EMAUpdater(object):
""" Model Exponential Moving Average
"""Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and
buffers). This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
......@@ -163,8 +163,7 @@ def may_get_ema_checkpointer(cfg, model):
def get_model_ema_state(model):
""" Return the ema state stored in `model`
"""
"""Return the ema state stored in `model`"""
model = _remove_ddp(model)
assert hasattr(model, "ema_state")
ema = model.ema_state
......@@ -172,7 +171,7 @@ def get_model_ema_state(model):
def apply_model_ema(model, state=None, save_current=False):
""" Apply ema stored in `model` to model and returns a function to restore
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)
......@@ -192,7 +191,7 @@ def apply_model_ema(model, state=None, save_current=False):
@contextmanager
def apply_model_ema_and_restore(model, state=None):
""" Apply ema stored in `model` to model and returns a function to restore
"""Apply ema stored in `model` to model and returns a function to restore
the weights are applied
"""
model = _remove_ddp(model)
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import re
import logging
import re
import torch.nn as nn
from detectron2.layers import FrozenBatchNorm2d
......@@ -32,8 +32,9 @@ def set_requires_grad(model, reg_exps, value):
if not matched:
unmatched_parameter_names.append(name)
unmatched_parameters.append(parameter)
logger.info("Matched layers (require_grad={}): {}".format(
value, matched_parameter_names))
logger.info(
"Matched layers (require_grad={}): {}".format(value, matched_parameter_names)
)
logger.info("Unmatched layers: {}".format(unmatched_parameter_names))
return matched_parameter_names, unmatched_parameter_names
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment