Commit 21153802 authored by surgan12's avatar surgan12 Committed by Francisco Massa
Browse files

normalise updates (#699)

* normalise

* some changes

* Update functional.py

* Update functional.py

* code changes
parent 885e3c20
......@@ -181,11 +181,11 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std):
def normalize(tensor, mean, std, inplace=False):
"""Normalize a tensor image with mean and standard deviation.
.. note::
This transform acts in-place, i.e., it mutates the input tensor.
This transform acts out of place by default, i.e., it does not mutates the input tensor.
See :class:`~torchvision.transforms.Normalize` for more details.
......@@ -200,9 +200,12 @@ def normalize(tensor, mean, std):
if not _is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.')
# This is faster than using broadcasting, don't change without benchmarking
for t, m, s in zip(tensor, mean, std):
t.sub_(m).div_(s)
if not inplace:
tensor = tensor.clone()
mean = torch.tensor(mean, dtype=torch.float32)
std = torch.tensor(std, dtype=torch.float32)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor
......
......@@ -136,16 +136,17 @@ class Normalize(object):
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts in-place, i.e., it mutates the input tensor.
This transform acts out of place, i.e., it does not mutates the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
"""
def __init__(self, mean, std):
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, tensor):
"""
......@@ -155,7 +156,7 @@ class Normalize(object):
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std)
return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment