"vscode:/vscode.git/clone" did not exist on "ec1aded12e8fbdc98d6703f9111ed824addebc77"
Commit ab49d0b6 authored by Anton Rigner's avatar Anton Rigner Committed by Facebook GitHub Bot
Browse files

Fix WeightedSampler to also work with adhoc datasets

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/437

# Problem
- We use `TRAIN_CATEGORIES` to overrider the classes for convenient experimentation, to not have to re-map the JSON file
- But it's not possible to use the WeightedTrainingSampler with specified repeat factors (`DATASETS.TRAIN_REPEAT_FACTOR`) when also overriding the classes to use for training (ad-hoc datasets), because the underlying dataset name doesn't match the datasets specified in the `TRAIN_REPEAT_FACTOR` pairs (mapping between <dataset_name, repeat_factor>)

# Fix

- Update the dataset names for the REPEAT_FACTORS mapping as well, if we have enabled the `WeightedTrainingSampler` and use ad-hoc datasets.

Reviewed By: wat3rBro

Differential Revision: D41765638

fbshipit-source-id: 51dad484e4d715d2de900b5d0b7c7caa19903fb7
parent b5e5b0ad
......@@ -393,6 +393,16 @@ def update_cfg_if_using_adhoc_dataset(cfg):
with temp_defrost(cfg):
cfg.DATASETS.TRAIN = tuple(ds.new_ds_name for ds in new_train_datasets)
# If present, we also need to update the data set names for the WeightedTrainingSampler
if cfg.DATASETS.TRAIN_REPEAT_FACTOR:
for ds_to_repeat_factor in cfg.DATASETS.TRAIN_REPEAT_FACTOR:
original_ds_name = ds_to_repeat_factor[0]
# Search corresponding data set name, to not rely on the order
for ds in new_train_datasets:
if ds.src_ds_name == original_ds_name:
ds_to_repeat_factor[0] = ds.new_ds_name
break
if cfg.D2GO_DATA.DATASETS.TEST_CATEGORIES:
new_test_datasets = [
COCOWithClassesToUse(ds, cfg.D2GO_DATA.DATASETS.TEST_CATEGORIES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment