Commit 8c86f74c authored by Kamran Melikov's avatar Kamran Melikov Committed by Kai Chen
Browse files

Non color images (#1976)

* First Draft

On branch non-color-images
Changes to be committed:
modified:   mmdet/datasets/pipelines/loading.py

* Add option to load non color images

Add 'color_type' parameter to LoadImageFromFile class
Change __repr__ method accordingly
Since non-color images maybe two dimensional expand image
dimensions if necessary in DefaultFormatBundle and
ImageToTensor classes

Changes to be committed:
    modified:   mmdet/datasets/pipelines/formating.py
    modified:   mmdet/datasets/pipelines/loading.py

* Fix RandomCrop to work with grayscale images

Changes to be committed:
    modified:   mmdet/datasets/pipelines/transforms.py

* Modify retrieving w, h of padded image in anchor heads

This addreses problems with single channel images for which the
shape is  tuple with 2 values

Changes to be committed:
    modified:   mmdet/models/anchor_heads/anchor_head.py
    modified:   mmdet/models/anchor_heads/guided_anchor_head.py
    modified:   mmdet/models/anchor_heads/reppoints_head.py
parent e60d34af
...@@ -52,7 +52,10 @@ class ImageToTensor(object): ...@@ -52,7 +52,10 @@ class ImageToTensor(object):
def __call__(self, results): def __call__(self, results):
for key in self.keys: for key in self.keys:
results[key] = to_tensor(results[key].transpose(2, 0, 1)) img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img.transpose(2, 0, 1))
return results return results
def __repr__(self): def __repr__(self):
...@@ -115,7 +118,10 @@ class DefaultFormatBundle(object): ...@@ -115,7 +118,10 @@ class DefaultFormatBundle(object):
def __call__(self, results): def __call__(self, results):
if 'img' in results: if 'img' in results:
img = np.ascontiguousarray(results['img'].transpose(2, 0, 1)) img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
results['img'] = DC(to_tensor(img), stack=True) results['img'] = DC(to_tensor(img), stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']: for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
if key not in results: if key not in results:
......
...@@ -10,8 +10,9 @@ from ..registry import PIPELINES ...@@ -10,8 +10,9 @@ from ..registry import PIPELINES
@PIPELINES.register_module @PIPELINES.register_module
class LoadImageFromFile(object): class LoadImageFromFile(object):
def __init__(self, to_float32=False): def __init__(self, to_float32=False, color_type='color'):
self.to_float32 = to_float32 self.to_float32 = to_float32
self.color_type = color_type
def __call__(self, results): def __call__(self, results):
if results['img_prefix'] is not None: if results['img_prefix'] is not None:
...@@ -19,7 +20,7 @@ class LoadImageFromFile(object): ...@@ -19,7 +20,7 @@ class LoadImageFromFile(object):
results['img_info']['filename']) results['img_info']['filename'])
else: else:
filename = results['img_info']['filename'] filename = results['img_info']['filename']
img = mmcv.imread(filename) img = mmcv.imread(filename, self.color_type)
if self.to_float32: if self.to_float32:
img = img.astype(np.float32) img = img.astype(np.float32)
results['filename'] = filename results['filename'] = filename
...@@ -29,8 +30,8 @@ class LoadImageFromFile(object): ...@@ -29,8 +30,8 @@ class LoadImageFromFile(object):
return results return results
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(to_float32={})'.format( return '{} (to_float32={}, color_type={})'.format(
self.to_float32) self.__class__.__name__, self.to_float32, self.color_type)
@PIPELINES.register_module @PIPELINES.register_module
......
...@@ -364,7 +364,7 @@ class RandomCrop(object): ...@@ -364,7 +364,7 @@ class RandomCrop(object):
crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
# crop the image # crop the image
img = img[crop_y1:crop_y2, crop_x1:crop_x2, :] img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
img_shape = img.shape img_shape = img.shape
results['img'] = img results['img'] = img
results['img_shape'] = img_shape results['img_shape'] = img_shape
......
...@@ -127,7 +127,7 @@ class AnchorHead(nn.Module): ...@@ -127,7 +127,7 @@ class AnchorHead(nn.Module):
for i in range(num_levels): for i in range(num_levels):
anchor_stride = self.anchor_strides[i] anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i] feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape'] h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags( flags = self.anchor_generators[i].valid_flags(
......
...@@ -246,7 +246,7 @@ class GuidedAnchorHead(AnchorHead): ...@@ -246,7 +246,7 @@ class GuidedAnchorHead(AnchorHead):
approxs = multi_level_approxs[i] approxs = multi_level_approxs[i]
anchor_stride = self.anchor_strides[i] anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i] feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape'] h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.approx_generators[i].valid_flags( flags = self.approx_generators[i].valid_flags(
......
...@@ -320,7 +320,7 @@ class RepPointsHead(nn.Module): ...@@ -320,7 +320,7 @@ class RepPointsHead(nn.Module):
for i in range(num_levels): for i in range(num_levels):
point_stride = self.point_strides[i] point_stride = self.point_strides[i]
feat_h, feat_w = featmap_sizes[i] feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape'] h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h) valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w) valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
flags = self.point_generators[i].valid_flags( flags = self.point_generators[i].valid_flags(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment