Unverified Commit 767b23ea authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Fixes no grad and range bugs in utils. (#3269)



* Fixes utils

* don't use any

* slightly simplify logic
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 631ff912
......@@ -2,23 +2,24 @@ from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import torch
import math
import warnings
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageFont
__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]
irange = range
@torch.no_grad()
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
......@@ -30,7 +31,7 @@ def make_grid(
padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
range (tuple, optional): tuple (min, max) where min and max are numbers,
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each (bool, optional): If ``True``, scale each image in the batch of
......@@ -43,7 +44,12 @@ def make_grid(
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead."
warnings.warn(warning)
value_range = kwargs["range"]
# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
......@@ -61,25 +67,25 @@ def make_grid(
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
if value_range is not None:
assert isinstance(value_range, tuple), \
"value_range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))
def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
def norm_range(t, value_range):
if value_range is not None:
norm_ip(t, value_range[0], value_range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
norm_range(t, value_range)
else:
norm_range(tensor, range)
norm_range(tensor, value_range)
if tensor.size(0) == 1:
return tensor.squeeze(0)
......@@ -92,8 +98,8 @@ def make_grid(
num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
# Tensor.copy_() is a valid method but seems to be missing from the stubs
......@@ -105,16 +111,13 @@ def make_grid(
return grid
@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a given Tensor into an image file.
......@@ -126,8 +129,8 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment