Commit 88babd54 authored by Gus-Guo's avatar Gus-Guo
Browse files

avoid logger being pickled

parent eaae7f47
...@@ -13,8 +13,8 @@ class DataAugmentor(object): ...@@ -13,8 +13,8 @@ class DataAugmentor(object):
self.data_augmentor_queue = [] self.data_augmentor_queue = []
for cur_cfg in augmentor_configs: for cur_cfg in augmentor_configs:
cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg) cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_augmentor_queue.append(cur_augmentor) self.data_augmentor_queue.append(cur_augmentor)
def gt_sampling(self, config=None): def gt_sampling(self, config=None):
db_sampler = database_sampler.DataBaseSampler( db_sampler = database_sampler.DataBaseSampler(
root_path=self.root_path, root_path=self.root_path,
...@@ -24,6 +24,13 @@ class DataAugmentor(object): ...@@ -24,6 +24,13 @@ class DataAugmentor(object):
) )
return db_sampler return db_sampler
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def random_world_flip(self, data_dict=None, config=None): def random_world_flip(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
......
...@@ -37,6 +37,14 @@ class DataBaseSampler(object): ...@@ -37,6 +37,14 @@ class DataBaseSampler(object):
'indices': np.arange(len(self.db_infos[class_name])) 'indices': np.arange(len(self.db_infos[class_name]))
} }
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def filter_by_difficulty(self, db_infos, removed_difficulty): def filter_by_difficulty(self, db_infos, removed_difficulty):
new_db_infos = {} new_db_infos = {}
for key, dinfos in db_infos.items(): for key, dinfos in db_infos.items():
......
...@@ -39,6 +39,14 @@ class DatasetTemplate(torch_data.Dataset): ...@@ -39,6 +39,14 @@ class DatasetTemplate(torch_data.Dataset):
def mode(self): def mode(self):
return 'train' if self.training else 'test' return 'train' if self.training else 'test'
def __getstate__(self):
d = dict(self.__dict__)
del d['logger']
return d
def __setstate__(self, d):
self.__dict__.update(d)
def __len__(self): def __len__(self):
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment