Commit 88babd54 authored by Gus-Guo's avatar Gus-Guo
Browse files

avoid logger being pickled

parent eaae7f47
......@@ -13,8 +13,8 @@ class DataAugmentor(object):
self.data_augmentor_queue = []
for cur_cfg in augmentor_configs:
cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_augmentor_queue.append(cur_augmentor)
self.data_augmentor_queue.append(cur_augmentor)
def gt_sampling(self, config=None):
db_sampler = database_sampler.DataBaseSampler(
root_path=self.root_path,
......@@ -24,6 +24,13 @@ class DataAugmentor(object):
)
return db_sampler
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
......
......@@ -37,6 +37,14 @@ class DataBaseSampler(object):
'indices': np.arange(len(self.db_infos[class_name]))
}
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {}
for key, dinfos in db_infos.items():
......
......@@ -39,6 +39,14 @@ class DatasetTemplate(torch_data.Dataset):
def mode(self):
return 'train' if self.training else 'test'
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def __len__(self):
raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment