Commit a03774d6 authored by YuanLiuuuuuu's avatar YuanLiuuuuuu Committed by zhouzaida
Browse files

[Refactor]: Grascale return uint8 type

parent 5867a97a
...@@ -692,6 +692,7 @@ class RandomGrayscale(BaseTransform): ...@@ -692,6 +692,7 @@ class RandomGrayscale(BaseTransform):
normalized_weights = ( normalized_weights = (
np.array(self.channel_weights) / sum(self.channel_weights)) np.array(self.channel_weights) / sum(self.channel_weights))
img = (normalized_weights * img).sum(axis=2) img = (normalized_weights * img).sum(axis=2)
img = img.astype('uint8')
if self.keep_channels: if self.keep_channels:
img = img[:, :, None] img = img[:, :, None]
results['img'] = np.dstack( results['img'] = np.dstack(
...@@ -699,6 +700,7 @@ class RandomGrayscale(BaseTransform): ...@@ -699,6 +700,7 @@ class RandomGrayscale(BaseTransform):
else: else:
results['img'] = img results['img'] = img
return results return results
img = img.astype('uint8')
results['img'] = img results['img'] = img
return results return results
......
...@@ -474,7 +474,7 @@ class TestRandomGrayscale: ...@@ -474,7 +474,7 @@ class TestRandomGrayscale:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cls.img = np.random.rand(10, 10, 3).astype(np.float32) cls.img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
def test_repr(self): def test_repr(self):
# test repr # test repr
...@@ -504,9 +504,9 @@ class TestRandomGrayscale: ...@@ -504,9 +504,9 @@ class TestRandomGrayscale:
random_gray_scale_module = TRANSFORMS.build(transform) random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img) results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img'] img = random_gray_scale_module(results)['img']
computed_gray = ( computed_gray = (self.img[:, :, 0] * 0.299 +
self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 + self.img[:, :, 1] * 0.587 +
self.img[:, :, 2] * 0.114) self.img[:, :, 2] * 0.114).astype(np.uint8)
for i in range(img.shape[2]): for i in range(img.shape[2]):
assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4) assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
assert img.shape == (10, 10, 3) assert img.shape == (10, 10, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment