Commit 2256bdb7 authored by Amethyst Reese, committed by Facebook GitHub Bot

apply Black 2024 style in fbcode (7/16)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447732

fbshipit-source-id: e21fbbe27882c8af183d021f4ac27029cbe93e8e
parent 09bd2869
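For context, the sketch below illustrates the two kinds of mechanical changes that appear in the hunks that follow: Black 2024 wraps conditional (ternary) expressions in their own parentheses when they span multiple lines, and it drops redundant parentheses such as those around for-loop targets and return annotations. This is an invented before/after pair, not code from the covered files; `build_loader_old`, `build_loader_new`, and `load_dataset` are hypothetical names, and only the config keys mirror the diff below.

# Hypothetical before/after sketch of the Black 2024 changes seen in this diff.
# load_dataset is an assumed stand-in helper; cfg is assumed to be a
# detectron2-style CfgNode supplied by the caller.
def load_dataset(name, min_keypoints=0):
    # Dummy loader so the sketch runs; a real loader would return dataset dicts.
    return [name] * (min_keypoints + 1)


def build_loader_old(cfg):
    # Pre-2024 Black style: the multi-line conditional is left bare inside the
    # call, and the loop target keeps its redundant parentheses.
    datasets = {
        name: load_dataset(
            name,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
        )
        for name in cfg.DATASETS.TRAIN
    }
    sizes = {}
    for (name, dataset) in datasets.items():
        sizes[name] = len(dataset)
    return datasets, sizes


def build_loader_new(cfg):
    # Black 2024 style: the same code after formatting; the multi-line
    # conditional gains enclosing parentheses and the loop target loses
    # its redundant ones.
    datasets = {
        name: load_dataset(
            name,
            min_keypoints=(
                cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
                if cfg.MODEL.KEYPOINT_ON
                else 0
            ),
        )
        for name in cfg.DATASETS.TRAIN
    }
    sizes = {}
    for name, dataset in datasets.items():
        sizes[name] = len(dataset)
    return datasets, sizes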
@@ -98,12 +98,16 @@ def build_weighted_detection_train_loader(
         name: get_detection_dataset_dicts(
             [name],
             filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
-            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
-            if cfg.MODEL.KEYPOINT_ON
-            else 0,
-            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
-            if cfg.MODEL.LOAD_PROPOSALS
-            else None,
+            min_keypoints=(
+                cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+                if cfg.MODEL.KEYPOINT_ON
+                else 0
+            ),
+            proposal_files=(
+                cfg.DATASETS.PROPOSAL_FILES_TRAIN
+                if cfg.MODEL.LOAD_PROPOSALS
+                else None
+            ),
         )
         for name in cfg.DATASETS.TRAIN
     }
@@ -67,9 +67,7 @@ class DiskCachedList(object):
         self._lst = [_serialize(x) for x in self._lst]
         total_size = sum(len(x) for x in self._lst)
         # TODO: only enabling DiskCachedDataset for large enough dataset
-        logger.info(
-            "Serialized dataset takes {:.2f} MiB".format(total_size / 1024**2)
-        )
+        logger.info("Serialized dataset takes {:.2f} MiB".format(total_size / 1024**2))
         self._initialize_diskcache()
     def _initialize_diskcache(self):
@@ -363,7 +363,7 @@ def convert_to_dict_list(
     default_record = {"dataset_name": dataset_name} if dataset_name else {}
     converted_dict_list = []
-    for (img_dict, anno_dict_list) in zip(imgs, anns):
+    for img_dict, anno_dict_list in zip(imgs, anns):
         record = copy.deepcopy(default_record)
         # NOTE: besides using (relative path) in the "file_name" filed to represent
@@ -89,7 +89,7 @@ def extended_lvis_load(json_file, image_root, dataset_name=None):
     dataset_dicts = []
     count_ignore_image_root_warning = 0
-    for (img_dict, anno_dict_list) in imgs_anns:
+    for img_dict, anno_dict_list in imgs_anns:
         record = {}
         if "://" not in img_dict["file_name"]:
             file_name = img_dict["file_name"]
@@ -198,7 +198,7 @@ class EMAUpdater(object):
                 averaged_model_parameters, model_parameters, 1.0 - decay
             )
             if self.debug_lerp:
-                for (orig_val, lerp_val) in zip(
+                for orig_val, lerp_val in zip(
                     orig_averaged_model_parameters, averaged_model_parameters
                 ):
                     assert torch.allclose(orig_val, lerp_val, rtol=1e-4, atol=1e-3)
@@ -282,7 +282,6 @@ def mock_quantization_type(quant_func):
 def default_prepare_for_quant(cfg, model):
     """
     Default implementation of preparing a model for quantization. This function will
     be called to before training if QAT is enabled, or before calibration during PTQ if
@@ -418,9 +418,11 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
     def from_config(cls, cfg: CfgNode):
         qat = cfg.QUANTIZATION.QAT
         callback = cls(
-            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
-            if cfg.QUANTIZATION.MODULES
-            else None,
+            qconfig_dicts=(
+                {submodule: None for submodule in cfg.QUANTIZATION.MODULES}
+                if cfg.QUANTIZATION.MODULES
+                else None
+            ),
             # We explicitly pass this to maintain properties for now.
             preserved_attrs=["model.backbone.size_divisibility"],
             start_step=qat.START_ITER,
@@ -576,9 +578,11 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
     @classmethod
     def from_config(cls, cfg: CfgNode):
         return cls(
-            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
-            if cfg.QUANTIZATION.MODULES
-            else None,
+            qconfig_dicts=(
+                {submodule: None for submodule in cfg.QUANTIZATION.MODULES}
+                if cfg.QUANTIZATION.MODULES
+                else None
+            ),
             # We explicitly pass this to maintain properties for now.
             preserved_attrs=["model.backbone.size_divisibility"],
         )
@@ -464,9 +464,11 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
                 output_folder,
                 train_iter,
                 model_tag,
-                model.module
-                if isinstance(model, nn.parallel.DistributedDataParallel)
-                else model,
+                (
+                    model.module
+                    if isinstance(model, nn.parallel.DistributedDataParallel)
+                    else model
+                ),
             )
             inference_callbacks = self._get_inference_callbacks()
@@ -35,9 +35,11 @@ class ActivationCheckpointModelingHook(mh.ModelingHook):
         logger.info("Activation Checkpointing is used")
         wrapper_fn = partial(
             checkpoint_wrapper,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT
-            if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
-            else CheckpointImpl.REENTRANT,
+            checkpoint_impl=(
+                CheckpointImpl.NO_REENTRANT
+                if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
+                else CheckpointImpl.REENTRANT
+            ),
         )
         policy_name = self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY
         assert (
@@ -56,7 +56,7 @@ def run_once(
     `MultipleFunctionCallError`.
     """
-    def decorator(func: Callable[..., T]) -> (Callable[..., T]):
+    def decorator(func: Callable[..., T]) -> Callable[..., T]:
         signal: List[T] = []
         @wraps(func)
@@ -10,9 +10,9 @@ def get_lt_trainer(output_dir: str, cfg):
    return pl.Trainer(
        max_epochs=10**8,
        max_steps=cfg.SOLVER.MAX_ITER,
-        val_check_interval=cfg.TEST.EVAL_PERIOD
-        if cfg.TEST.EVAL_PERIOD > 0
-        else cfg.SOLVER.MAX_ITER,
+        val_check_interval=(
+            cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else cfg.SOLVER.MAX_ITER
+        ),
        callbacks=[checkpoint_callback],
        logger=False,
    )
@@ -90,9 +90,11 @@ def main():
         logger.info(
             "{}: {} in {:.2f}s".format(
                 path,
-                "detected {} instances".format(len(predictions["instances"]))
-                if "instances" in predictions
-                else "finished",
+                (
+                    "detected {} instances".format(len(predictions["instances"]))
+                    if "instances" in predictions
+                    else "finished"
+                ),
                 time.time() - start_time,
             )
         )
@@ -67,9 +67,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     params = {
         "max_epochs": -1,
         "max_steps": cfg.SOLVER.MAX_ITER,
-        "val_check_interval": cfg.TEST.EVAL_PERIOD
-        if cfg.TEST.EVAL_PERIOD > 0
-        else cfg.SOLVER.MAX_ITER,
+        "val_check_interval": (
+            cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else cfg.SOLVER.MAX_ITER
+        ),
         "num_nodes": comm.get_num_nodes(),
         "devices": comm.get_local_size(),
         "strategy": strategy,
@@ -78,11 +78,11 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
-        "precision": parse_precision_from_string(
-            cfg.SOLVER.AMP.PRECISION, lightning=True
-        )
-        if cfg.SOLVER.AMP.ENABLED
-        else 32,
+        "precision": (
+            parse_precision_from_string(cfg.SOLVER.AMP.PRECISION, lightning=True)
+            if cfg.SOLVER.AMP.ENABLED
+            else 32
+        ),
     }
     if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
         if (