Unverified Commit 6eff0a43 authored by vfdev's avatar vfdev Committed by GitHub

Added tensor transforms jupyter notebook (#2730)

* [WIP] Added scriptable transforms python example

* Replaced script file with jupyter notebook

* Updated readme

* Updates according to review
+ updated docstrings

* Updated notebook and docstring according to the review

* torch script -> torchscript
parent 4106dbb8
@@ -14,6 +14,24 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
random transformations applied on a batch of Tensor Images identically transform all the images of the batch.
.. warning::

    Since v0.8.0, all random transformations use the torch default random generator to sample random parameters.
    This is a backwards-incompatible change, and users should set the random state as follows:

    .. code:: python

        # Previous versions
        # import random
        # random.seed(12)

        # Now
        import torch
        torch.manual_seed(17)

    Please keep in mind that the same seed for the torch random generator and the Python random generator will not
    produce the same results.
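A minimal sketch of the new behaviour (plain torch random draws stand in here for a transform's internal parameter sampling):

```python
import torch

# Since v0.8.0, random transforms draw their parameters from the torch
# default generator, so re-seeding it reproduces the same random choices.
torch.manual_seed(17)
a = torch.rand(2, 3)

torch.manual_seed(17)
b = torch.rand(2, 3)

assert torch.equal(a, b)  # identical draws after re-seeding
```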
Scriptable transforms
---------------------
@@ -33,6 +51,7 @@ Make sure to use only scriptable transformations, i.e. that work with ``torch.Te
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
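As a sketch, a hypothetical custom transform (`ClampTransform`, not part of torchvision) derived from `torch.nn.Module` scripts cleanly:

```python
import torch
import torch.nn as nn


class ClampTransform(nn.Module):
    """Hypothetical custom transform: clamps values to [lo, hi]."""

    def __init__(self, lo: float, hi: float):
        super().__init__()
        self.lo = lo
        self.hi = hi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clamp(self.lo, self.hi)


# Because it derives from nn.Module, torch.jit.script accepts it.
scripted = torch.jit.script(ClampTransform(0.0, 1.0))
out = scripted(torch.tensor([-0.5, 0.5, 1.5]))
```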
Compositions of transforms
--------------------------
# Python examples
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
Prior to v0.8.0, transforms in torchvision were PIL-centric, which imposed several limitations. Since v0.8.0,
transform implementations are compatible with both Tensor and PIL images, which enables the following new features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformations, e.g. for videos
- read and decode data directly as torch tensors with torchscript support (for PNG and JPEG image formats)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vjAC2mZnb4nz"
},
"source": [
"# Image transformations\n",
"\n",
"This notebook shows new features of torchvision image transformations. \n",
"\n",
"Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to that. Now, since v0.8.0, transforms implementations are Tensor and PIL compatible and we can achieve the following new \n",
"features:\n",
"- transform multi-band torch tensor images (with more than 3-4 channels) \n",
"- torchscript transforms together with your model for deployment\n",
"- support for GPU acceleration\n",
"- batched transformation such as for videos\n",
"- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "btaDWPDbgIyW",
"outputId": "8a83d408-f643-42da-d247-faf3a1bd3ae0"
},
"outputs": [],
"source": [
"import torch, torchvision\n",
"torch.__version__, torchvision.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Vj9draNb4oA"
},
"source": [
"## Transforms on CPU/CUDA tensor images\n",
"\n",
"Let's show how to apply transformations on images opened directly as a torch tensors.\n",
"Now, torchvision provides image reading functions for PNG and JPG images with torchscript support. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Epp3hCy0b4oD"
},
"outputs": [],
"source": [
"from torchvision.datasets.utils import download_url\n",
"\n",
"download_url(\"https://farm1.static.flickr.com/152/434505223_8d1890e1e2.jpg\", \".\", \"test-image.jpg\")\n",
"download_url(\"https://farm3.static.flickr.com/2142/1896267403_24939864ba.jpg\", \".\", \"test-image2.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-m7lYDPb4oK"
},
"outputs": [],
"source": [
"import matplotlib.pylab as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 303
},
"id": "5bi8Q7L3b4oc",
"outputId": "e5de5c73-e16d-4992-ebee-94c7ddf0bf54"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"tensor_image = read_image(\"test-image.jpg\")\n",
"\n",
"print(\"tensor image info: \", tensor_image.shape, tensor_image.dtype)\n",
"\n",
"plt.imshow(tensor_image.numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_rgb_image(tensor):\n",
" \"\"\"Helper method to get RGB numpy array for plotting\"\"\"\n",
" np_img = tensor.cpu().numpy().transpose((1, 2, 0))\n",
" m1, m2 = np_img.min(axis=(0, 1)), np_img.max(axis=(0, 1))\n",
" return (255.0 * (np_img - m1) / (m2 - m1)).astype(\"uint8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "PgWpjxQ3b4pF",
"outputId": "e9a138e8-b45c-4f75-d849-3b41de0e5472"
},
"outputs": [],
"source": [
"import torchvision.transforms as T\n",
"\n",
"# to fix random seed is now:\n",
"torch.manual_seed(12)\n",
"\n",
"transforms = T.Compose([\n",
" T.RandomCrop(224),\n",
" T.RandomHorizontalFlip(p=0.3),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])\n",
"\n",
"out_image = transforms(tensor_image)\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmYQB4cxb4pI"
},
"source": [
"Tensor images can be on GPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "S6syYJGEb4pN",
"outputId": "86bddb64-e648-45f2-c216-790d43cfc26d"
},
"outputs": [],
"source": [
"out_image = transforms(tensor_image.to(\"cuda\"))\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype, out_image.device)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jg9TQd7ajfyn"
},
"source": [
"## Scriptable transforms for easier deployment via torchscript\n",
"\n",
"Next, we show how to combine input transformations and model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n",
"\n",
"**Note:** we have to use only scriptable transformations that should be derived from `torch.nn.Module`. \n",
"Since v0.8.0, all transformations are scriptable except `Compose`, `RandomApply`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images. \n",
"The transformations like `Compose` are kept for backward compatibility and can be easily replaced by existing torch modules, like `nn.Sequential`.\n",
"\n",
"Let's define a module `Predictor` that transforms input tensor and applies ImageNet pretrained resnet18 model on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NSDOJ3RajfvO"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
"from torchvision.io.image import read_image\n",
"from torchvision.models import resnet18\n",
"\n",
"\n",
"class Predictor(nn.Module):\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.resnet18 = resnet18(pretrained=True).eval()\n",
" self.transforms = nn.Sequential(\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" with torch.no_grad():\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return y_pred.argmax(dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZZKDovqej5vA"
},
"source": [
"Now, let's define scripted and non-scripted instances of `Predictor` and apply on multiple tensor images of the same size"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GBBMSo7vjfr0"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"predictor = Predictor().to(\"cuda\")\n",
"scripted_predictor = torch.jit.script(predictor).to(\"cuda\")\n",
"\n",
"\n",
"tensor_image1 = read_image(\"test-image.jpg\")\n",
"tensor_image2 = read_image(\"test-image2.jpg\")\n",
"batch = torch.stack([tensor_image1[:, -320:, :], tensor_image2[:, -320:, :]]).to(\"cuda\")\n",
"\n",
"res1 = scripted_predictor(batch)\n",
"res2 = predictor(batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 501
},
"id": "Dmi9r_p-oKsk",
"outputId": "b9c55e7d-5db1-4975-c485-fecc4075bf47"
},
"outputs": [],
"source": [
"import json\n",
"from torchvision.datasets.utils import download_url\n",
"\n",
"\n",
"download_url(\"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\", \".\", \"imagenet_class_index.json\")\n",
"\n",
"\n",
"with open(\"imagenet_class_index.json\", \"r\") as h:\n",
" labels = json.load(h)\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res1):\n",
" plt.subplot(1, 2, i + 1)\n",
" plt.title(\"Scripted predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res2):\n",
" plt.subplot(1, 2, i + 1)\n",
" plt.title(\"Original predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7IYsjzpFqcK8"
},
"source": [
"We save and reload scripted predictor in Python or C++ and use it for inference:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"id": "0kk9LLw5jfol",
"outputId": "05ea6db7-7fcf-4b74-a763-5f117c14cc00"
},
"outputs": [],
"source": [
"scripted_predictor.save(\"scripted_predictor.pt\")\n",
"\n",
"scripted_predictor = torch.jit.load(\"scripted_predictor.pt\")\n",
"res1 = scripted_predictor(batch)\n",
"\n",
"for i, p in enumerate(res1):\n",
" print(\"Scripted predictor: {label})\".format(label=labels[str(p.item())]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Data reading and decoding functions also support torch script and therefore can be part of the model as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AnotherPredictor(Predictor):\n",
"\n",
" def forward(self, path: str) -> int:\n",
" with torch.no_grad():\n",
" x = read_image(path).unsqueeze(0)\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return int(y_pred.argmax(dim=1).item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-cMwTs3Yjffy"
},
"outputs": [],
"source": [
"scripted_predictor2 = torch.jit.script(AnotherPredictor())\n",
"\n",
"res = scripted_predictor2(\"test-image.jpg\")\n",
"\n",
"print(\"Scripted another predictor: {label})\".format(label=labels[str(res)]))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "torchvision_scriptable_transforms.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -33,7 +33,8 @@ _pil_interpolation_to_str = {
class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please see the note below.
    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
@@ -76,7 +77,7 @@ class Compose:
class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
@@ -107,7 +108,7 @@ class ToTensor:
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.
    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """
@@ -153,7 +154,7 @@ class ConvertImageDtype(torch.nn.Module):
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.
    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.
@@ -373,7 +374,7 @@ class Pad(torch.nn.Module):
class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.
    Args:
        lambd (function): Lambda/function to be used for transform.
@@ -415,7 +416,8 @@ class RandomTransforms:
class RandomApply(RandomTransforms):
    """Apply randomly a list of transformations with a given probability.
    This transform does not support torchscript.
    Args:
        transforms (list or tuple): list of transformations
@@ -444,7 +446,7 @@ class RandomApply(RandomTransforms):
class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript.
    """
    def __call__(self, img):
        order = list(range(len(self.transforms)))
@@ -455,7 +457,7 @@ class RandomOrder(RandomTransforms):
class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
    """
    def __call__(self, img):
        t = random.choice(self.transforms)