Commit 57994044 authored by yhcao6's avatar yhcao6
Browse files

fix util

parent d6b69bda
......@@ -8,6 +8,7 @@ import torch
import matplotlib.pyplot as plt
import numpy as np
from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset
from .. import datasets
......@@ -74,6 +75,11 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg):
repeat_times = None
if data_cfg['type'] == 'RepeatDataset':
repeat_times = data_cfg['repeat_times']
data_cfg = data_cfg['dataset']
if isinstance(data_cfg['ann_file'], (list, tuple)):
ann_files = data_cfg['ann_file']
num_dset = len(ann_files)
......@@ -108,4 +114,7 @@ def get_dataset(data_cfg):
dset = ConcatDataset(dsets)
else:
dset = dsets[0]
if repeat_times is not None:
dset = RepeatDataset(dset, repeat_times)
return dset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment