Commit 21153802 authored by surgan12's avatar surgan12 Committed by Francisco Massa
Browse files

normalise updates (#699)

* normalise

* some changes

* Update functional.py

* Update functional.py

* code changes
parent 885e3c20
...@@ -181,11 +181,11 @@ def to_pil_image(pic, mode=None): ...@@ -181,11 +181,11 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std): def normalize(tensor, mean, std, inplace=False):
"""Normalize a tensor image with mean and standard deviation. """Normalize a tensor image with mean and standard deviation.
.. note:: .. note::
This transform acts in-place, i.e., it mutates the input tensor. This transform acts out of place by default, i.e., it does not mutates the input tensor.
See :class:`~torchvision.transforms.Normalize` for more details. See :class:`~torchvision.transforms.Normalize` for more details.
...@@ -200,9 +200,12 @@ def normalize(tensor, mean, std): ...@@ -200,9 +200,12 @@ def normalize(tensor, mean, std):
if not _is_tensor_image(tensor): if not _is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor is not a torch image.')
# This is faster than using broadcasting, don't change without benchmarking if not inplace:
for t, m, s in zip(tensor, mean, std): tensor = tensor.clone()
t.sub_(m).div_(s)
mean = torch.tensor(mean, dtype=torch.float32)
std = torch.tensor(std, dtype=torch.float32)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor return tensor
......
...@@ -136,16 +136,17 @@ class Normalize(object): ...@@ -136,16 +136,17 @@ class Normalize(object):
``input[channel] = (input[channel] - mean[channel]) / std[channel]`` ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note:: .. note::
This transform acts in-place, i.e., it mutates the input tensor. This transform acts out of place, i.e., it does not mutates the input tensor.
Args: Args:
mean (sequence): Sequence of means for each channel. mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel. std (sequence): Sequence of standard deviations for each channel.
""" """
def __init__(self, mean, std): def __init__(self, mean, std, inplace=False):
self.mean = mean self.mean = mean
self.std = std self.std = std
self.inplace = inplace
def __call__(self, tensor): def __call__(self, tensor):
""" """
...@@ -155,7 +156,7 @@ class Normalize(object): ...@@ -155,7 +156,7 @@ class Normalize(object):
Returns: Returns:
Tensor: Normalized Tensor image. Tensor: Normalized Tensor image.
""" """
return F.normalize(tensor, self.mean, self.std) return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment