"examples/pytorch/vscode:/vscode.git/clone" did not exist on "ba110e50e61f19e6cce3a7bf6166d92000e2641d"
Unverified Commit bc75766c authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

convert masks to ndarray in transforms (#2030)

parent 80ae224f
...@@ -156,7 +156,7 @@ class Resize(object): ...@@ -156,7 +156,7 @@ class Resize(object):
mmcv.imresize(mask, mask_size, interpolation='nearest') mmcv.imresize(mask, mask_size, interpolation='nearest')
for mask in results[key] for mask in results[key]
] ]
results[key] = masks results[key] = np.stack(masks)
def _resize_seg(self, results): def _resize_seg(self, results):
for key in results.get('seg_fields', []): for key in results.get('seg_fields', []):
...@@ -245,10 +245,10 @@ class RandomFlip(object): ...@@ -245,10 +245,10 @@ class RandomFlip(object):
results['flip_direction']) results['flip_direction'])
# flip masks # flip masks
for key in results.get('mask_fields', []): for key in results.get('mask_fields', []):
results[key] = [ results[key] = np.stack([
mmcv.imflip(mask, direction=results['flip_direction']) mmcv.imflip(mask, direction=results['flip_direction'])
for mask in results[key] for mask in results[key]
] ])
# flip segs # flip segs
for key in results.get('seg_fields', []): for key in results.get('seg_fields', []):
...@@ -410,7 +410,7 @@ class RandomCrop(object): ...@@ -410,7 +410,7 @@ class RandomCrop(object):
gt_mask = results['gt_masks'][i][crop_y1:crop_y2, gt_mask = results['gt_masks'][i][crop_y1:crop_y2,
crop_x1:crop_x2] crop_x1:crop_x2]
valid_gt_masks.append(gt_mask) valid_gt_masks.append(gt_mask)
results['gt_masks'] = valid_gt_masks results['gt_masks'] = np.stack(valid_gt_masks)
return results return results
...@@ -586,7 +586,7 @@ class Expand(object): ...@@ -586,7 +586,7 @@ class Expand(object):
0).astype(mask.dtype) 0).astype(mask.dtype)
expand_mask[top:top + h, left:left + w] = mask expand_mask[top:top + h, left:left + w] = mask
expand_gt_masks.append(expand_mask) expand_gt_masks.append(expand_mask)
results['gt_masks'] = expand_gt_masks results['gt_masks'] = np.stack(expand_gt_masks)
# not tested # not tested
if 'gt_semantic_seg' in results: if 'gt_semantic_seg' in results:
...@@ -678,10 +678,10 @@ class MinIoURandomCrop(object): ...@@ -678,10 +678,10 @@ class MinIoURandomCrop(object):
results['gt_masks'][i] for i in range(len(mask)) results['gt_masks'][i] for i in range(len(mask))
if mask[i] if mask[i]
] ]
results['gt_masks'] = [ results['gt_masks'] = np.stack([
gt_mask[patch[1]:patch[3], patch[0]:patch[2]] gt_mask[patch[1]:patch[3], patch[0]:patch[2]]
for gt_mask in valid_masks for gt_mask in valid_masks
] ])
# not tested # not tested
if 'gt_semantic_seg' in results: if 'gt_semantic_seg' in results:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment