Unverified Commit 82ce5194 authored by yukang's avatar yukang Committed by GitHub
Browse files

Update dataset.py

parent 867e41a0
...@@ -122,15 +122,17 @@ class DatasetTemplate(torch_data.Dataset): ...@@ -122,15 +122,17 @@ class DatasetTemplate(torch_data.Dataset):
if self.training: if self.training:
assert 'gt_boxes' in data_dict, 'gt_boxes should be provided for training' assert 'gt_boxes' in data_dict, 'gt_boxes should be provided for training'
gt_boxes_mask = np.array([n in self.class_names for n in data_dict['gt_names']], dtype=np.bool_) gt_boxes_mask = np.array([n in self.class_names for n in data_dict['gt_names']], dtype=np.bool_)
calib = data_dict['calib'] if 'calib' in data_dict:
calib = data_dict['calib']
data_dict = self.data_augmentor.forward( data_dict = self.data_augmentor.forward(
data_dict={ data_dict={
**data_dict, **data_dict,
'gt_boxes_mask': gt_boxes_mask 'gt_boxes_mask': gt_boxes_mask
} }
) )
data_dict['calib'] = calib if 'calib' in data_dict:
data_dict['calib'] = calib
if data_dict.get('gt_boxes', None) is not None: if data_dict.get('gt_boxes', None) is not None:
selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names) selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
data_dict['gt_boxes'] = data_dict['gt_boxes'][selected] data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment