"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "b52a25ca31624563c6af5587f131fc262b438d7f"
Unverified commit ea197e4e authored by Vasilis Vryniotis, committed by GitHub

Remove `to_tensor()` and `ToTensor()` usages (#5553)

* Remove from models and references.

* Adding most tests and docs.

* Adding transforms tests.

* Remove unnecessary IPython notebook.

* Simplify tests.

* Addressing comments.
parent 4176b632
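At its core, the change swaps the one-step `ToTensor()` (PIL image to `float32` tensor scaled to `[0, 1]`) for the explicit two-step `PILToTensor()` plus `ConvertImageDtype(torch.float)`. A minimal sketch of the equivalence, on a synthetic Pillow image (not part of the diff below):

```python
# Hedged sketch of the migration pattern; `img` is a synthetic stand-in.
import torch
from PIL import Image
from torchvision import transforms as T

img = Image.new("RGB", (4, 4), color=(128, 64, 255))

# Old: PIL -> float32 tensor in [0, 1] in one step.
old = T.ToTensor()(img)

# New: PIL -> uint8 tensor, then an explicit dtype conversion that rescales.
new = T.ConvertImageDtype(torch.float)(T.PILToTensor()(img))

assert new.dtype == torch.float32
torch.testing.assert_close(old, new)
```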
@@ -179,7 +179,7 @@ to::
     import torch
     from torchvision import datasets, transforms as T

-    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
+    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.PILToTensor(), T.ConvertImageDtype(torch.float)])
     dataset = datasets.ImageNet(".", split="train", transform=transform)
     means = []
...
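The docs snippet above feeds the updated pipeline into a per-channel means loop. A hedged sketch of how that computation runs end to end, with `FakeData` standing in for the ImageNet download (an assumption for runnability, not part of the diff):

```python
# Sketch only: datasets.FakeData replaces datasets.ImageNet to stay runnable.
import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.PILToTensor(), T.ConvertImageDtype(torch.float)])
dataset = datasets.FakeData(size=8, image_size=(3, 256, 256), transform=transform)

means = []
for img, _ in dataset:
    means.append(img.mean(dim=(1, 2)))  # per-channel mean of one image
print(torch.stack(means).mean(dim=0))   # aggregate channel means
```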
@@ -41,7 +41,12 @@ class DetectionPresetTrain:
 class DetectionPresetEval:
     def __init__(self):
-        self.transforms = T.ToTensor()
+        self.transforms = T.Compose(
+            [
+                T.PILToTensor(),
+                T.ConvertImageDtype(torch.float),
+            ]
+        )

     def __call__(self, img, target):
         return self.transforms(img, target)
@@ -45,15 +45,6 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
         return image, target


-class ToTensor(nn.Module):
-    def forward(
-        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
-    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
-        image = F.pil_to_tensor(image)
-        image = F.convert_image_dtype(image)
-        return image, target
-
-
 class PILToTensor(nn.Module):
     def forward(
         self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
...
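The hunk above deletes the reference `ToTensor` module outright. Downstream scripts that relied on it can compose the two remaining building blocks; a hypothetical drop-in with the same `(image, target)` convention (the class name is illustrative, not a torchvision API):

```python
from typing import Dict, Optional, Tuple

from torch import Tensor, nn
from torchvision.transforms import functional as F


class PILToFloatTensor(nn.Module):  # hypothetical name, not part of torchvision
    def forward(
        self, image, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        image = F.pil_to_tensor(image)        # PIL -> uint8 CHW tensor
        image = F.convert_image_dtype(image)  # uint8 -> float32 in [0, 1]
        return image, target
```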
 import unittest
 from collections import defaultdict

+import torch
 import torchvision.transforms as transforms
 from sampler import PKSampler
 from torch.utils.data import DataLoader
@@ -17,7 +18,13 @@ class Tester(unittest.TestCase):
         self.assertRaises(AssertionError, PKSampler, targets, p, k)

         # Ensure p, k constraints on batch
-        dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=transforms.ToTensor())
+        trans = transforms.Compose(
+            [
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(torch.float),
+            ]
+        )
+        dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=trans)
         targets = [target.item() for _, target in dataset]
         sampler = PKSampler(targets, p, k)
         loader = DataLoader(dataset, batch_size=p * k, sampler=sampler)
...
@@ -102,7 +102,12 @@ def main(args):
     optimizer = Adam(model.parameters(), lr=args.lr)

     transform = transforms.Compose(
-        [transforms.Lambda(lambda image: image.convert("RGB")), transforms.Resize((224, 224)), transforms.ToTensor()]
+        [
+            transforms.Lambda(lambda image: image.convert("RGB")),
+            transforms.Resize((224, 224)),
+            transforms.PILToTensor(),
+            transforms.ConvertImageDtype(torch.float),
+        ]
     )

     # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can
...
@@ -33,7 +33,8 @@ if __name__ == "__main__":
         [
             transforms.RandomSizedCrop(224),
             transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
+            transforms.ConvertImageDtype(torch.float),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         ]
     )
...
@@ -30,7 +30,8 @@ def read_image1():
     )
     image = Image.open(image_path)
     image = image.resize((224, 224))
-    x = F.to_tensor(image)
+    x = F.pil_to_tensor(image)
+    x = F.convert_image_dtype(x)
     return x.view(1, 3, 224, 224)

@@ -40,7 +41,8 @@ def read_image2():
     )
     image = Image.open(image_path)
     image = image.resize((299, 299))
-    x = F.to_tensor(image)
+    x = F.pil_to_tensor(image)
+    x = F.convert_image_dtype(x)
     x = x.view(1, 3, 299, 299)
     return torch.cat([x, x], 0)
...
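The same two functional calls replace `F.to_tensor` in both helpers above. A hedged standalone sketch mirroring `read_image1`, on a synthetic image rather than the test asset:

```python
import torch
from PIL import Image
from torchvision.transforms import functional as F

image = Image.new("RGB", (224, 224))
x = F.pil_to_tensor(image)    # uint8 tensor of shape (3, 224, 224)
x = F.convert_image_dtype(x)  # float32 in [0, 1]; dtype defaults to torch.float
assert x.shape == (3, 224, 224) and x.dtype == torch.float32
```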
@@ -413,13 +413,13 @@ class TestONNXExporter:
         import os

         from PIL import Image
-        from torchvision import transforms
+        from torchvision.transforms import functional as F

         data_dir = os.path.join(os.path.dirname(__file__), "assets")
         path = os.path.join(data_dir, *rel_path.split("/"))
         image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)

-        return transforms.ToTensor()(image)
+        return F.convert_image_dtype(F.pil_to_tensor(image))

     def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         return (
...
@@ -154,7 +154,7 @@ class TestConvertImageDtype:
 @pytest.mark.skipif(accimage is None, reason="accimage not available")
 class TestAccImage:
     def test_accimage_to_tensor(self):
-        trans = transforms.ToTensor()
+        trans = transforms.PILToTensor()

         expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB"))
         output = trans(accimage.Image(GRACE_HOPPER))
@@ -174,7 +174,8 @@ class TestAccImage:
         trans = transforms.Compose(
             [
                 transforms.Resize(256, interpolation=Image.LINEAR),
-                transforms.ToTensor(),
+                transforms.PILToTensor(),
+                transforms.ConvertImageDtype(dtype=torch.float),
             ]
         )
@@ -192,10 +193,7 @@ class TestAccImage:
     def test_accimage_crop(self):
         trans = transforms.Compose(
-            [
-                transforms.CenterCrop(256),
-                transforms.ToTensor(),
-            ]
+            [transforms.CenterCrop(256), transforms.PILToTensor(), transforms.ConvertImageDtype(dtype=torch.float)]
         )

     # Checking if Compose, CenterCrop and ToTensor can be printed as string
@@ -457,26 +455,24 @@ class TestPad:
     def test_pad(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
-        img = torch.ones(3, height, width)
+        img = torch.ones(3, height, width, dtype=torch.uint8)
         padding = random.randint(1, 20)
         fill = random.randint(1, 50)
         result = transforms.Compose(
             [
                 transforms.ToPILImage(),
                 transforms.Pad(padding, fill=fill),
-                transforms.ToTensor(),
+                transforms.PILToTensor(),
             ]
         )(img)
         assert result.size(1) == height + 2 * padding
         assert result.size(2) == width + 2 * padding

         # check that all elements in the padded region correspond
         # to the pad value
-        fill_v = fill / 255
-        eps = 1e-5
         h_padded = result[:, :padding, :]
         w_padded = result[:, :, :padding]
-        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps)
-        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps)
+        torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill), rtol=0.0, atol=0.0)
+        torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill), rtol=0.0, atol=0.0)
         pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img))

     def test_pad_with_tuple_of_pad_values(self):
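The dtype change above is why `fill_v = fill / 255` and the `eps` tolerance could be dropped: `PILToTensor` keeps the uint8 pixel values, so the padded region equals the integer fill exactly. A small sketch with fixed sizes (hypothetical values, not the test's random ones):

```python
import torch
from torchvision import transforms

img = torch.ones(3, 20, 20, dtype=torch.uint8)
fill, padding = 7, 2
result = transforms.Compose(
    [transforms.ToPILImage(), transforms.Pad(padding, fill=fill), transforms.PILToTensor()]
)(img)
assert (result[:, :padding, :] == fill).all()  # exact integer comparison
```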
@@ -509,7 +505,7 @@ class TestPad:
         # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
         edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35)
+        assert transforms.PILToTensor()(edge_padded_img).size() == (3, 35, 35)

         # Pad 3 to left/right, 2 to top/bottom
         reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect")
@@ -517,7 +513,7 @@ class TestPad:
         # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
         reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35)
+        assert transforms.PILToTensor()(reflect_padded_img).size() == (3, 33, 35)

         # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
         symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric")
@@ -525,7 +521,7 @@ class TestPad:
         # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
         symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
         assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34)
+        assert transforms.PILToTensor()(symmetric_padded_img).size() == (3, 32, 34)

         # Check negative padding explicitly for symmetric case, since it is not
         # implemented for tensor case to compare to
@@ -535,7 +531,7 @@ class TestPad:
         symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
         assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
         assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8))
-        assert transforms.ToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)
+        assert transforms.PILToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)

     def test_pad_raises_with_invalid_pad_sequence_len(self):
         with pytest.raises(ValueError):
@@ -1625,12 +1621,12 @@ def test_random_crop():
     width = random.randint(10, 32) * 2
     oheight = random.randint(5, (height - 2) / 2) * 2
     owidth = random.randint(5, (width - 2) / 2) * 2
-    img = torch.ones(3, height, width)
+    img = torch.ones(3, height, width, dtype=torch.uint8)
     result = transforms.Compose(
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == oheight
@@ -1641,14 +1637,14 @@ def test_random_crop():
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((oheight, owidth), padding=padding),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == oheight
     assert result.size(2) == owidth

     result = transforms.Compose(
-        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.ToTensor()]
+        [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.PILToTensor()]
     )(img)
     assert result.size(1) == height
     assert result.size(2) == width
@@ -1658,7 +1654,7 @@ def test_random_crop():
         [
             transforms.ToPILImage(),
             transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.size(1) == height + 1
@@ -1676,7 +1672,7 @@ def test_center_crop():
     oheight = random.randint(5, (height - 2) / 2) * 2
     owidth = random.randint(5, (width - 2) / 2) * 2

-    img = torch.ones(3, height, width)
+    img = torch.ones(3, height, width, dtype=torch.uint8)
     oh1 = (height - oheight) // 2
     ow1 = (width - owidth) // 2
     imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth]
@@ -1685,7 +1681,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     assert result.sum() == 0
@@ -1695,7 +1691,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     sum1 = result.sum()
@@ -1706,7 +1702,7 @@ def test_center_crop():
         [
             transforms.ToPILImage(),
             transforms.CenterCrop((oheight, owidth)),
-            transforms.ToTensor(),
+            transforms.PILToTensor(),
         ]
     )(img)
     sum2 = result.sum()
@@ -1729,12 +1725,12 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
     delta_height *= delta
     delta_width *= delta

-    img = torch.ones(3, *input_image_size)
+    img = torch.ones(3, *input_image_size, dtype=torch.uint8)
     crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)

     # Test both transforms, one with PIL input and one with tensor
     output_pil = transforms.Compose(
-        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.ToTensor()],
+        [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.PILToTensor()],
     )(img)
     assert output_pil.size()[1:3] == crop_size
@@ -1893,13 +1889,13 @@ def test_randomperspective():
     perp = transforms.RandomPerspective()
     startpoints, endpoints = perp.get_params(width, height, 0.5)
     tr_img = F.perspective(img, startpoints, endpoints)
-    tr_img2 = F.to_tensor(F.perspective(tr_img, endpoints, startpoints))
-    tr_img = F.to_tensor(tr_img)
+    tr_img2 = F.convert_image_dtype(F.pil_to_tensor(F.perspective(tr_img, endpoints, startpoints)))
+    tr_img = F.convert_image_dtype(F.pil_to_tensor(tr_img))
     assert img.size[0] == width
     assert img.size[1] == height
-    assert torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > torch.nn.functional.mse_loss(
-        tr_img2, F.to_tensor(img)
-    )
+    assert torch.nn.functional.mse_loss(
+        tr_img, F.convert_image_dtype(F.pil_to_tensor(img))
+    ) + 0.3 > torch.nn.functional.mse_loss(tr_img2, F.convert_image_dtype(F.pil_to_tensor(img)))

 @pytest.mark.parametrize("seed", range(10))
...
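In the perspective test above, `convert_image_dtype` is applied before `mse_loss` because `pil_to_tensor` yields uint8 and MSE needs floating-point inputs. A small illustration on synthetic images (not the test's fixtures):

```python
import torch
from PIL import Image
from torchvision.transforms import functional as F

a = F.convert_image_dtype(F.pil_to_tensor(Image.new("RGB", (8, 8), (10, 20, 30))))
b = F.convert_image_dtype(F.pil_to_tensor(Image.new("RGB", (8, 8), (11, 20, 30))))
print(torch.nn.functional.mse_loss(a, b))  # small float loss over [0, 1] values
```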
@@ -76,7 +76,7 @@ def test_save_image_file_object():
     fp = BytesIO()
     utils.save_image(t, fp, format="png")
     img_bytes = Image.open(fp)
-    assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
+    assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")

 @pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows")
@@ -88,7 +88,7 @@ def test_save_image_single_pixel_file_object():
     fp = BytesIO()
     utils.save_image(t, fp, format="png")
     img_bytes = Image.open(fp)
-    assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object")
+    assert_equal(F.pil_to_tensor(img_orig), F.pil_to_tensor(img_bytes), msg="Image not stored in file object")

 def test_draw_boxes():
...
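Comparing the saved and reloaded images with `pil_to_tensor` keeps both sides in uint8, so `assert_equal` is an exact integer comparison with no float tolerance. A hedged round-trip sketch:

```python
from io import BytesIO

import torch
from PIL import Image
from torchvision import utils
from torchvision.transforms import functional as F

t = torch.rand(1, 3, 8, 8)
fp = BytesIO()
utils.save_image(t, fp, format="png")
fp.seek(0)
reloaded = F.pil_to_tensor(Image.open(fp))
assert reloaded.dtype == torch.uint8  # integer pixels: comparisons are exact
```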
@@ -32,7 +32,7 @@ class CelebA(VisionDataset):
             Defaults to ``attr``. If empty, ``None`` will be returned as target.
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         download (bool, optional): If true, downloads the dataset from the internet and
...
@@ -15,7 +15,7 @@ class CocoDetection(VisionDataset):
         root (string): Root directory where images are downloaded to.
         annFile (string): Path to json annotation file.
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         transforms (callable, optional): A function/transform that takes input sample and its target as entry
@@ -66,7 +66,7 @@ class CocoCaptions(CocoDetection):
         root (string): Root directory where images are downloaded to.
         annFile (string): Path to json annotation file.
         transform (callable, optional): A function/transform that takes in an PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         transforms (callable, optional): A function/transform that takes input sample and its target as entry
@@ -80,7 +80,7 @@ class CocoCaptions(CocoDetection):
         import torchvision.transforms as transforms
         cap = dset.CocoCaptions(root = 'dir where images are',
                                 annFile = 'json annotation file',
-                                transform=transforms.ToTensor())
+                                transform=transforms.PILToTensor())

         print('Number of samples: ', len(cap))
         img, target = cap[3] # load 4th sample
...
@@ -59,7 +59,7 @@ class Flickr8k(VisionDataset):
         root (string): Root directory where images are downloaded to.
         ann_file (string): Path to annotation file.
         transform (callable, optional): A function/transform that takes in a PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
@@ -115,7 +115,7 @@ class Flickr30k(VisionDataset):
         root (string): Root directory where images are downloaded to.
         ann_file (string): Path to annotation file.
         transform (callable, optional): A function/transform that takes in a PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """
...
@@ -30,7 +30,7 @@ class Kitti(VisionDataset):
         train (bool, optional): Use ``train`` split if true, else ``test`` split.
             Defaults to ``train``.
         transform (callable, optional): A function/transform that takes in a PIL image
-            and returns a transformed version. E.g, ``transforms.ToTensor``
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         transforms (callable, optional): A function/transform that takes input sample
...
@@ -980,7 +980,7 @@ class FiveCrop(torch.nn.Module):
     Example:
          >>> transform = Compose([
          >>>    FiveCrop(size), # this is a list of PIL Images
-         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
          >>> ])
          >>> #In your test loop you can do the following:
          >>> input, target = batch # input is a 5d tensor, target is 2d
@@ -1029,7 +1029,7 @@ class TenCrop(torch.nn.Module):
     Example:
          >>> transform = Compose([
          >>>    TenCrop(size), # this is a list of PIL Images
-         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
+         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
          >>> ])
          >>> #In your test loop you can do the following:
          >>> input, target = batch # input is a 5d tensor, target is 2d
...
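One caveat of the updated docstring examples: stacking `PILToTensor()` outputs yields a uint8 batch, so a `ConvertImageDtype` step is still needed before float-only ops such as `Normalize`. A hedged sketch of the pattern on a synthetic image:

```python
import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.FiveCrop(24),  # returns a tuple of five PIL crops
    transforms.Lambda(lambda crops: torch.stack([transforms.PILToTensor()(c) for c in crops])),
    transforms.ConvertImageDtype(torch.float),
])
batch = transform(Image.new("RGB", (64, 64)))
assert batch.shape == (5, 3, 24, 24) and batch.dtype == torch.float32
```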