Commit f50c306a authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: support merge_all_iters_to_one_epoch in dataloader

parent ed6f3dd2
...@@ -33,9 +33,8 @@ class DistributedSampler(_DistributedSampler): ...@@ -33,9 +33,8 @@ class DistributedSampler(_DistributedSampler):
return iter(indices) return iter(indices)
def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4, def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
logger=None, training=True): logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0):
dataset = __all__[dataset_cfg.DATASET]( dataset = __all__[dataset_cfg.DATASET](
dataset_cfg=dataset_cfg, dataset_cfg=dataset_cfg,
...@@ -44,6 +43,11 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, ...@@ -44,6 +43,11 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
training=training, training=training,
logger=logger, logger=logger,
) )
if merge_all_iters_to_one_epoch:
assert hasattr(dataset, 'merge_all_iters_to_one_epoch')
dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
if dist: if dist:
if training: if training:
sampler = torch.utils.data.distributed.DistributedSampler(dataset) sampler = torch.utils.data.distributed.DistributedSampler(dataset)
......
...@@ -96,7 +96,9 @@ def main(): ...@@ -96,7 +96,9 @@ def main():
batch_size=args.batch_size, batch_size=args.batch_size,
dist=dist_train, workers=args.workers, dist=dist_train, workers=args.workers,
logger=logger, logger=logger,
training=True training=True,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
total_epochs=args.epochs
) )
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=train_set) model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=train_set)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment