Commit b9209b69 authored by Fu-Chen Chen's avatar Fu-Chen Chen Committed by Facebook GitHub Bot
Browse files

add default config for RandomSubsetTrainingSampler in D2go

Summary:
Add default config for RandomSubsetTrainingSampler in D2 (https://github.com/facebookresearch/d2go/commit/3ee8885047e7ffb9eadcc6a1ecf8253c7ce9f79e)go.

User can use use the RandomSubsetTrainingSampler with the following yaml configs
```
DATALOADER:
  SAMPLER_TRAIN: RandomSubsetTrainingSampler
  RANDOM_SUBSET_RATIO: [Desired_ratio]  # for RandomSubsetTrainingSampler
```

Reviewed By: XiaoliangDai

Differential Revision: D29892366

fbshipit-source-id: cabb67fb46e51a93a8342a42f77a8a4d23a933e9
parent 3ee88850
...@@ -44,6 +44,14 @@ def add_weighted_training_sampler_default_configs(cfg: CfgNode): ...@@ -44,6 +44,14 @@ def add_weighted_training_sampler_default_configs(cfg: CfgNode):
cfg.DATASETS.TRAIN_REPEAT_FACTOR = [] cfg.DATASETS.TRAIN_REPEAT_FACTOR = []
def add_random_subset_training_sampler_default_configs(cfg: CfgNode):
"""
Add default cfg.DATALOADER.RANDOM_SUBSET_RATIO for RandomSubsetTrainingSampler
The CfgNode under cfg.DATALOADER.RANDOM_SUBSET_RATIO should be a float > 0 and <= 1
"""
cfg.DATALOADER.RANDOM_SUBSET_RATIO = 1.
def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]: def get_train_datasets_repeat_factors(cfg: CfgNode) -> Dict[str, float]:
repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR
assert all(len(tup) == 2 for tup in repeat_factors) assert all(len(tup) == 2 for tup in repeat_factors)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.data.build import ( from d2go.data.build import (
add_weighted_training_sampler_default_configs, add_weighted_training_sampler_default_configs,
add_random_subset_training_sampler_default_configs,
) )
from d2go.data.config import add_d2go_data_default_configs from d2go.data.config import add_d2go_data_default_configs
from d2go.modeling.backbone.fbnet_cfg import ( from d2go.modeling.backbone.fbnet_cfg import (
...@@ -52,6 +53,8 @@ def get_default_cfg(_C): ...@@ -52,6 +53,8 @@ def get_default_cfg(_C):
add_quantization_default_configs(_C) add_quantization_default_configs(_C)
# _C.DATASETS.TRAIN_REPEAT_FACTOR # _C.DATASETS.TRAIN_REPEAT_FACTOR
add_weighted_training_sampler_default_configs(_C) add_weighted_training_sampler_default_configs(_C)
# _C.DATALOADER.RANDOM_SUBSET_RATIO
add_random_subset_training_sampler_default_configs(_C)
# _C.ABNORMAL_CHECKER # _C.ABNORMAL_CHECKER
add_abnormal_checker_configs(_C) add_abnormal_checker_configs(_C)
# _C.MODEL.SUBCLASS # _C.MODEL.SUBCLASS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment