Unverified Commit e1d289c1 authored by missionfloyd's avatar missionfloyd Committed by GitHub
Browse files

use slice instead of torch.select()

parent e12fb88b
...@@ -1076,7 +1076,7 @@ class ImageToMask: ...@@ -1076,7 +1076,7 @@ class ImageToMask:
def image_to_mask(self, image, channel): def image_to_mask(self, image, channel):
channels = ["red", "green", "blue"] channels = ["red", "green", "blue"]
mask = torch.select(image[0], 2, channels.index(channel)) mask = image[0, :, :, channels.index(channel)]
return (mask,) return (mask,)
class MaskToImage: class MaskToImage:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment