"src/diffusers/commands/diffusers_cli.py" did not exist on "07ffe73f796db1c19555fee04711f1ab71a92de2"
Commit 94b027bb authored by Devin Zhou's avatar Devin Zhou Committed by Facebook GitHub Bot
Browse files

Enable Class Balancing for Model Train Sampler

Summary:
This diff enables both category and datasets weight balancing at the same time by declaring "WeightedCategoryTrainingSampler" under "SAMPLER_TRAIN" in config file.

X-link: https://github.com/facebookresearch/detectron2/pull/4995

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/570

Reviewed By: jiaxuzhu92, shiyud

Differential Revision: D46377371

fbshipit-source-id: 4e8bdf6a7e5d40b04072cb99637d13d85b2e0fce
parent 78328839
...@@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]: ...@@ -66,7 +66,9 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
return name_to_weight return name_to_weight
def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): def build_weighted_detection_train_loader(
cfg: CfgNode, mapper=None, enable_category_balance=False
):
dataset_repeat_factors = get_train_datasets_repeat_factors(cfg) dataset_repeat_factors = get_train_datasets_repeat_factors(cfg)
# OrderedDict to guarantee order of values() consistent with repeat factors # OrderedDict to guarantee order of values() consistent with repeat factors
dataset_name_to_dicts = OrderedDict( dataset_name_to_dicts = OrderedDict(
...@@ -98,12 +100,39 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None): ...@@ -98,12 +100,39 @@ def build_weighted_detection_train_loader(cfg: CfgNode, mapper=None):
mapper = DatasetMapper(cfg, True) mapper = DatasetMapper(cfg, True)
dataset = MapDataset(dataset, mapper) dataset = MapDataset(dataset, mapper)
logger.info( repeat_factors = torch.tensor(repeat_factors)
"Using WeightedTrainingSampler with repeat_factors={}".format( if enable_category_balance:
cfg.DATASETS.TRAIN_REPEAT_FACTOR """
1. Calculate repeat factors using category frequency for each dataset and then merge them.
2. Element wise dot producting the dataset frequency repeat factors with
the category frequency repeat factors gives the final repeat factors.
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
)
for dataset_dict in dataset_name_to_dicts.values()
]
# flatten the category repeat factors from all datasets
category_repeat_factors = list(
itertools.chain.from_iterable(category_repeat_factors)
)
category_repeat_factors = torch.tensor(category_repeat_factors)
repeat_factors = torch.mul(category_repeat_factors, repeat_factors)
repeat_factors = repeat_factors / torch.min(repeat_factors)
logger.info(
"Using WeightedCategoryTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
)
else:
logger.info(
"Using WeightedTrainingSampler with repeat_factors={}".format(
cfg.DATASETS.TRAIN_REPEAT_FACTOR
)
) )
)
sampler = RepeatFactorTrainingSampler(torch.tensor(repeat_factors)) sampler = RepeatFactorTrainingSampler(repeat_factors)
return build_batch_data_loader( return build_batch_data_loader(
dataset, dataset,
...@@ -149,7 +178,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work ...@@ -149,7 +178,13 @@ def build_clip_grouping_data_loader(dataset, sampler, total_batch_size, num_work
@fb_overwritable() @fb_overwritable()
def build_mapped_train_loader(cfg, mapper): def build_mapped_train_loader(cfg, mapper):
if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler": if cfg.DATALOADER.SAMPLER_TRAIN == "WeightedTrainingSampler":
# balancing only datasets frequencies
data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper) data_loader = build_weighted_detection_train_loader(cfg, mapper=mapper)
elif cfg.DATALOADER.SAMPLER_TRAIN == "WeightedCategoryTrainingSampler":
# balancing both datasets and its categories
data_loader = build_weighted_detection_train_loader(
cfg, mapper=mapper, enable_category_balance=True
)
else: else:
data_loader = build_detection_train_loader(cfg, mapper=mapper) data_loader = build_detection_train_loader(cfg, mapper=mapper)
return data_loader return data_loader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment