Commit 2256bdb7 authored by Amethyst Reese, committed by Facebook GitHub Bot

apply Black 2024 style in fbcode (7/16)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447732

fbshipit-source-id: e21fbbe27882c8af183d021f4ac27029cbe93e8e
parent 09bd2869
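
Note for readers skimming the hunks below: the changes are mechanical. The Black 2024 style wraps a conditional expression in its own parentheses when it would otherwise span multiple lines inside a call or collection, and it drops redundant parentheses around for-loop targets and return annotations. Here is a minimal before/after sketch; the names (make_loader, eval_period_cfg, max_iter, items) are invented for illustration and do not come from the d2go files touched in this diff, and the example is shortened for readability (in practice the parenthesized form only appears when the expression does not fit on one line):

# Illustrative only -- make_loader, eval_period_cfg, max_iter, and items are invented names.
def make_loader(eval_period):
    return {"eval_period": eval_period}

eval_period_cfg = 0
max_iter = 1000

# Pre-2024 Black style: the multi-line conditional expression hangs off the keyword.
loader = make_loader(
    eval_period=eval_period_cfg
    if eval_period_cfg > 0
    else max_iter,
)

# Black 2024 style: the same expression is wrapped in its own parentheses,
# which is what most of the hunks below do.
loader = make_loader(
    eval_period=(
        eval_period_cfg if eval_period_cfg > 0 else max_iter
    ),
)

# Black 2024 also drops redundant parentheses around for-loop targets,
# e.g. "for (k, v) in items:" becomes:
items = [("lr", 0.01), ("momentum", 0.9)]
for k, v in items:
    print(k, v)

The removal of parentheses around a return annotation (the run_once hunk below) follows the same principle of dropping parentheses that carry no meaning.
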
@@ -98,12 +98,16 @@ def build_weighted_detection_train_loader(
         name: get_detection_dataset_dicts(
             [name],
             filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
-            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
-            if cfg.MODEL.KEYPOINT_ON
-            else 0,
-            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
-            if cfg.MODEL.LOAD_PROPOSALS
-            else None,
+            min_keypoints=(
+                cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+                if cfg.MODEL.KEYPOINT_ON
+                else 0
+            ),
+            proposal_files=(
+                cfg.DATASETS.PROPOSAL_FILES_TRAIN
+                if cfg.MODEL.LOAD_PROPOSALS
+                else None
+            ),
         )
         for name in cfg.DATASETS.TRAIN
     }
...
@@ -67,9 +67,7 @@ class DiskCachedList(object):
         self._lst = [_serialize(x) for x in self._lst]
         total_size = sum(len(x) for x in self._lst)
         # TODO: only enabling DiskCachedDataset for large enough dataset
-        logger.info(
-            "Serialized dataset takes {:.2f} MiB".format(total_size / 1024**2)
-        )
+        logger.info("Serialized dataset takes {:.2f} MiB".format(total_size / 1024**2))
         self._initialize_diskcache()

     def _initialize_diskcache(self):
...
@@ -363,7 +363,7 @@ def convert_to_dict_list(
     default_record = {"dataset_name": dataset_name} if dataset_name else {}
     converted_dict_list = []
-    for (img_dict, anno_dict_list) in zip(imgs, anns):
+    for img_dict, anno_dict_list in zip(imgs, anns):
         record = copy.deepcopy(default_record)
         # NOTE: besides using (relative path) in the "file_name" filed to represent
...
@@ -89,7 +89,7 @@ def extended_lvis_load(json_file, image_root, dataset_name=None):
     dataset_dicts = []
     count_ignore_image_root_warning = 0
-    for (img_dict, anno_dict_list) in imgs_anns:
+    for img_dict, anno_dict_list in imgs_anns:
         record = {}
         if "://" not in img_dict["file_name"]:
             file_name = img_dict["file_name"]
...
@@ -198,7 +198,7 @@ class EMAUpdater(object):
                 averaged_model_parameters, model_parameters, 1.0 - decay
             )
             if self.debug_lerp:
-                for (orig_val, lerp_val) in zip(
+                for orig_val, lerp_val in zip(
                     orig_averaged_model_parameters, averaged_model_parameters
                 ):
                     assert torch.allclose(orig_val, lerp_val, rtol=1e-4, atol=1e-3)
...
@@ -282,7 +282,6 @@ def mock_quantization_type(quant_func):
-


 def default_prepare_for_quant(cfg, model):
     """
     Default implementation of preparing a model for quantization. This function will
     be called to before training if QAT is enabled, or before calibration during PTQ if
...
@@ -418,9 +418,11 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
     def from_config(cls, cfg: CfgNode):
         qat = cfg.QUANTIZATION.QAT
         callback = cls(
-            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
-            if cfg.QUANTIZATION.MODULES
-            else None,
+            qconfig_dicts=(
+                {submodule: None for submodule in cfg.QUANTIZATION.MODULES}
+                if cfg.QUANTIZATION.MODULES
+                else None
+            ),
             # We explicitly pass this to maintain properties for now.
             preserved_attrs=["model.backbone.size_divisibility"],
             start_step=qat.START_ITER,
@@ -576,9 +578,11 @@ class PostTrainingQuantization(Callback, QuantizationMixin):
     @classmethod
     def from_config(cls, cfg: CfgNode):
         return cls(
-            qconfig_dicts={submodule: None for submodule in cfg.QUANTIZATION.MODULES}
-            if cfg.QUANTIZATION.MODULES
-            else None,
+            qconfig_dicts=(
+                {submodule: None for submodule in cfg.QUANTIZATION.MODULES}
+                if cfg.QUANTIZATION.MODULES
+                else None
+            ),
             # We explicitly pass this to maintain properties for now.
             preserved_attrs=["model.backbone.size_divisibility"],
         )
...
@@ -464,9 +464,11 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
                 output_folder,
                 train_iter,
                 model_tag,
-                model.module
-                if isinstance(model, nn.parallel.DistributedDataParallel)
-                else model,
+                (
+                    model.module
+                    if isinstance(model, nn.parallel.DistributedDataParallel)
+                    else model
+                ),
             )
             inference_callbacks = self._get_inference_callbacks()
...
@@ -35,9 +35,11 @@ class ActivationCheckpointModelingHook(mh.ModelingHook):
         logger.info("Activation Checkpointing is used")
         wrapper_fn = partial(
             checkpoint_wrapper,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT
-            if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
-            else CheckpointImpl.REENTRANT,
+            checkpoint_impl=(
+                CheckpointImpl.NO_REENTRANT
+                if not self.cfg.ACTIVATION_CHECKPOINT.REENTRANT
+                else CheckpointImpl.REENTRANT
+            ),
         )
         policy_name = self.cfg.ACTIVATION_CHECKPOINT.AUTO_WRAP_POLICY
         assert (
...
@@ -56,7 +56,7 @@ def run_once(
         `MultipleFunctionCallError`.
     """

-    def decorator(func: Callable[..., T]) -> (Callable[..., T]):
+    def decorator(func: Callable[..., T]) -> Callable[..., T]:
         signal: List[T] = []

         @wraps(func)
...
@@ -10,9 +10,9 @@ def get_lt_trainer(output_dir: str, cfg):
     return pl.Trainer(
         max_epochs=10**8,
         max_steps=cfg.SOLVER.MAX_ITER,
-        val_check_interval=cfg.TEST.EVAL_PERIOD
-        if cfg.TEST.EVAL_PERIOD > 0
-        else cfg.SOLVER.MAX_ITER,
+        val_check_interval=(
+            cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else cfg.SOLVER.MAX_ITER
+        ),
         callbacks=[checkpoint_callback],
         logger=False,
     )
...
@@ -90,9 +90,11 @@ def main():
             logger.info(
                 "{}: {} in {:.2f}s".format(
                     path,
-                    "detected {} instances".format(len(predictions["instances"]))
-                    if "instances" in predictions
-                    else "finished",
+                    (
+                        "detected {} instances".format(len(predictions["instances"]))
+                        if "instances" in predictions
+                        else "finished"
+                    ),
                     time.time() - start_time,
                 )
             )
...
@@ -67,9 +67,9 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
     params = {
         "max_epochs": -1,
         "max_steps": cfg.SOLVER.MAX_ITER,
-        "val_check_interval": cfg.TEST.EVAL_PERIOD
-        if cfg.TEST.EVAL_PERIOD > 0
-        else cfg.SOLVER.MAX_ITER,
+        "val_check_interval": (
+            cfg.TEST.EVAL_PERIOD if cfg.TEST.EVAL_PERIOD > 0 else cfg.SOLVER.MAX_ITER
+        ),
         "num_nodes": comm.get_num_nodes(),
         "devices": comm.get_local_size(),
         "strategy": strategy,
@@ -78,11 +78,11 @@ def get_trainer_params(cfg: CfgNode) -> Dict[str, Any]:
         "logger": TensorBoardLogger(save_dir=cfg.OUTPUT_DIR),
         "num_sanity_val_steps": 0,
         "replace_sampler_ddp": False,
-        "precision": parse_precision_from_string(
-            cfg.SOLVER.AMP.PRECISION, lightning=True
-        )
-        if cfg.SOLVER.AMP.ENABLED
-        else 32,
+        "precision": (
+            parse_precision_from_string(cfg.SOLVER.AMP.PRECISION, lightning=True)
+            if cfg.SOLVER.AMP.ENABLED
+            else 32
+        ),
     }
     if cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
         if (
...