Unverified Commit 7b87af25 authored by Sofiane Abbar's avatar Sofiane Abbar Committed by GitHub
Browse files

replaced deprecated call to ByteTensor with from_numpy (#3813)



replaced byteTensor with from_numpy

fixed lint issues and copy related worning
Co-authored-by: default avatarSofiane Abbar <sofa@fb.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent b0601631
...@@ -124,17 +124,13 @@ def to_tensor(pic): ...@@ -124,17 +124,13 @@ def to_tensor(pic):
return torch.from_numpy(nppic).to(dtype=default_float_dtype) return torch.from_numpy(nppic).to(dtype=default_float_dtype)
# handle PIL Image # handle PIL Image
if pic.mode == 'I': mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(np.array(pic, np.int32, copy=False)) img = torch.from_numpy(
elif pic.mode == 'I;16': np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
img = torch.from_numpy(np.array(pic, np.int16, copy=False)) )
elif pic.mode == 'F':
img = torch.from_numpy(np.array(pic, np.float32, copy=False))
elif pic.mode == '1':
img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
if pic.mode == '1':
img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format # put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous() img = img.permute((2, 0, 1)).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment