Commit f50c306a authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfix: support merge_all_iters_to_one_epoch in dataloader

parent ed6f3dd2
......@@ -33,9 +33,8 @@ class DistributedSampler(_DistributedSampler):
return iter(indices)
def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
logger=None, training=True):
logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0):
dataset = __all__[dataset_cfg.DATASET](
dataset_cfg=dataset_cfg,
......@@ -44,6 +43,11 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
training=training,
logger=logger,
)
if merge_all_iters_to_one_epoch:
assert hasattr(dataset, 'merge_all_iters_to_one_epoch')
dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
if dist:
if training:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
......
......@@ -96,7 +96,9 @@ def main():
batch_size=args.batch_size,
dist=dist_train, workers=args.workers,
logger=logger,
training=True
training=True,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
total_epochs=args.epochs
)
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=train_set)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment