Unverified Commit 767b23ea authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Fixes no grad and range bugs in utils. (#3269)



* Fixes utils

* don't use any

* slightly simplify logic
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 631ff912
...@@ -2,23 +2,24 @@ from typing import Union, Optional, List, Tuple, Text, BinaryIO ...@@ -2,23 +2,24 @@ from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib import pathlib
import torch import torch
import math import math
import warnings
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from PIL import ImageFont from PIL import ImageFont
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] __all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
irange = range
@torch.no_grad()
def make_grid( def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]], tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8, nrow: int = 8,
padding: int = 2, padding: int = 2,
normalize: bool = False, normalize: bool = False,
range: Optional[Tuple[int, int]] = None, value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False, scale_each: bool = False,
pad_value: int = 0, pad_value: int = 0,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
"""Make a grid of images. """Make a grid of images.
...@@ -30,7 +31,7 @@ def make_grid( ...@@ -30,7 +31,7 @@ def make_grid(
padding (int, optional): amount of padding. Default: ``2``. padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1), normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``. by the min and max values specified by :attr:`range`. Default: ``False``.
range (tuple, optional): tuple (min, max) where min and max are numbers, value_range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max then these numbers are used to normalize the image. By default, min and max
are computed from the tensor. are computed from the tensor.
scale_each (bool, optional): If ``True``, scale each image in the batch of scale_each (bool, optional): If ``True``, scale each image in the batch of
...@@ -43,7 +44,12 @@ def make_grid( ...@@ -43,7 +44,12 @@ def make_grid(
""" """
if not (torch.is_tensor(tensor) or if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead."
warnings.warn(warning)
value_range = kwargs["range"]
# if list of tensors, convert to a 4D mini-batch Tensor # if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list): if isinstance(tensor, list):
...@@ -61,25 +67,25 @@ def make_grid( ...@@ -61,25 +67,25 @@ def make_grid(
if normalize is True: if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None: if value_range is not None:
assert isinstance(range, tuple), \ assert isinstance(value_range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers" "value_range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, low, high): def norm_ip(img, low, high):
img.clamp_(min=low, max=high) img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5)) img.sub_(low).div_(max(high - low, 1e-5))
def norm_range(t, range): def norm_range(t, value_range):
if range is not None: if value_range is not None:
norm_ip(t, range[0], range[1]) norm_ip(t, value_range[0], value_range[1])
else: else:
norm_ip(t, float(t.min()), float(t.max())) norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True: if scale_each is True:
for t in tensor: # loop over mini-batch dimension for t in tensor: # loop over mini-batch dimension
norm_range(t, range) norm_range(t, value_range)
else: else:
norm_range(tensor, range) norm_range(tensor, value_range)
if tensor.size(0) == 1: if tensor.size(0) == 1:
return tensor.squeeze(0) return tensor.squeeze(0)
...@@ -92,8 +98,8 @@ def make_grid( ...@@ -92,8 +98,8 @@ def make_grid(
num_channels = tensor.size(1) num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0 k = 0
for y in irange(ymaps): for y in range(ymaps):
for x in irange(xmaps): for x in range(xmaps):
if k >= nmaps: if k >= nmaps:
break break
# Tensor.copy_() is a valid method but seems to be missing from the stubs # Tensor.copy_() is a valid method but seems to be missing from the stubs
...@@ -105,16 +111,13 @@ def make_grid( ...@@ -105,16 +111,13 @@ def make_grid(
return grid return grid
@torch.no_grad()
def save_image( def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]], tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO], fp: Union[Text, pathlib.Path, BinaryIO],
nrow: int = 8,
padding: int = 2,
normalize: bool = False, normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
format: Optional[str] = None, format: Optional[str] = None,
**kwargs
) -> None: ) -> None:
"""Save a given Tensor into an image file. """Save a given Tensor into an image file.
...@@ -126,8 +129,8 @@ def save_image( ...@@ -126,8 +129,8 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used. If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``. **kwargs: Other arguments are documented in ``make_grid``.
""" """
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each) grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment