"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "67b3d3267e407ca6d9d5c41dc37f0f7d3ae29116"
Commit 49ffc846 authored by Wei Ye, committed by Facebook GitHub Bot
Browse files

print sampling probability for WeightedTrainingSampler

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/610

As titled

Reviewed By: wat3rBro

Differential Revision: D48461077

fbshipit-source-id: f0bfd0dc9b8615b958a68d35c3df25a6c52859c0
parent f7e1b47e
......@@ -23,6 +23,7 @@ from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import RepeatFactorTrainingSampler
from detectron2.utils.comm import get_world_size
from mobile_cv.common.misc.oss_utils import fb_overwritable
from tabulate import tabulate
logger = logging.getLogger(__name__)
......@@ -66,6 +67,27 @@ def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
return name_to_weight
def get_sampling_probability_table(
    dataset_sizes: Dict[str, int], dataset_repeat_factors: Dict[str, float]
) -> str:
    """Render a table of per-dataset sampling probabilities.

    Each dataset is weighted by ``repeat_factor * size``; its sampling
    probability is that weight's share of the total weight over all datasets
    in ``dataset_sizes``, expressed as a percentage.

    Args:
        dataset_sizes: Maps dataset name to its number of samples.
        dataset_repeat_factors: Maps dataset name to its repeat factor.
            Datasets missing from this mapping use a factor of 1.0.

    Returns:
        A table in tabulate's "pipe" format with one row per entry of
        ``dataset_sizes`` (in iteration order) and the columns
        Dataset / Samples / Repeat factor / Sample Prob (%).
    """
    # Resolve the repeat factor and weight once per dataset.
    weighted = []
    for dsname, size in dataset_sizes.items():
        factor = dataset_repeat_factors.get(dsname, 1.0)
        weighted.append((dsname, size, factor, factor * size))

    total_weight = sum(weight for _, _, _, weight in weighted)

    # This helper only feeds a log message, so a zero total weight (all
    # datasets empty or all repeat factors 0) must not raise
    # ZeroDivisionError; report 0% for every dataset instead.
    sample_prob_data = [
        (
            dsname,
            size,
            factor,
            weight * 100 / total_weight if total_weight else 0.0,
        )
        for dsname, size, factor, weight in weighted
    ]
    headers = ["Dataset", "Samples", "Repeat factor", "Sample Prob (%)"]
    return tabulate(sample_prob_data, headers=headers, tablefmt="pipe")
def build_weighted_detection_train_loader(
cfg: CfgNode, mapper=None, enable_category_balance=False
):
......@@ -91,6 +113,11 @@ def build_weighted_detection_train_loader(
[dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname])
for dsname in cfg.DATASETS.TRAIN
]
sampling_prob_table = get_sampling_probability_table(
{dsname: len(dataset_name_to_dicts[dsname]) for dsname in cfg.DATASETS.TRAIN},
dataset_repeat_factors,
)
logger.info("Dataset TRAIN sampling probability: \n" + sampling_prob_table)
repeat_factors = list(itertools.chain.from_iterable(repeat_factors))
dataset_dicts = dataset_name_to_dicts.values()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment