Unverified Commit 77e41870 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Speed up equalize transform: use bincount instead of histc (#3493)

* use bincount instead of histc

* only use bincount when on CPU

* Added equality test for CPU vs cuda

* Fix flake8 and tests

* tuple instead of int for size
parent 414427dd
......@@ -977,6 +977,18 @@ class CUDATester(Tester):
def setUp(self):
    # Run every test in this class on the GPU; the base Tester
    # presumably defaults to "cpu" — TODO confirm against base class.
    self.device = "cuda"
def test_scale_channel(self):
    """Make sure that _scale_channel gives the same results on CPU and GPU as
    histc or bincount are used depending on the device.
    """
    # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed,
    # only use bincount and remove that test.
    shape = (1_000,)
    channel = torch.randint(0, 256, size=shape).to('cpu')
    result_cpu = F_t._scale_channel(channel)
    result_cuda = F_t._scale_channel(channel.to('cuda'))
    self.assertTrue(result_cpu.equal(result_cuda.to('cpu')))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -902,7 +902,14 @@ def autocontrast(img: Tensor) -> Tensor:
def _scale_channel(img_chan):
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = nonzero_hist[:-1].sum() // 255
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment