Unverified Commit bc75766c authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

convert masks to ndarray in transforms (#2030)

parent 80ae224f
......@@ -156,7 +156,7 @@ class Resize(object):
mmcv.imresize(mask, mask_size, interpolation='nearest')
for mask in results[key]
]
results[key] = masks
results[key] = np.stack(masks)
def _resize_seg(self, results):
for key in results.get('seg_fields', []):
......@@ -245,10 +245,10 @@ class RandomFlip(object):
results['flip_direction'])
# flip masks
for key in results.get('mask_fields', []):
results[key] = [
results[key] = np.stack([
mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key]
]
])
# flip segs
for key in results.get('seg_fields', []):
......@@ -410,7 +410,7 @@ class RandomCrop(object):
gt_mask = results['gt_masks'][i][crop_y1:crop_y2,
crop_x1:crop_x2]
valid_gt_masks.append(gt_mask)
results['gt_masks'] = valid_gt_masks
results['gt_masks'] = np.stack(valid_gt_masks)
return results
......@@ -586,7 +586,7 @@ class Expand(object):
0).astype(mask.dtype)
expand_mask[top:top + h, left:left + w] = mask
expand_gt_masks.append(expand_mask)
results['gt_masks'] = expand_gt_masks
results['gt_masks'] = np.stack(expand_gt_masks)
# not tested
if 'gt_semantic_seg' in results:
......@@ -678,10 +678,10 @@ class MinIoURandomCrop(object):
results['gt_masks'][i] for i in range(len(mask))
if mask[i]
]
results['gt_masks'] = [
results['gt_masks'] = np.stack([
gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks
]
])
# not tested
if 'gt_semantic_seg' in results:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment