Unverified Commit 6eff0a43 authored by vfdev's avatar vfdev Committed by GitHub

Added tensor transforms jupyter notebook (#2730)

* [WIP] Added scriptable transforms python example

* Replaced script file with jupyter notebook

* Updated readme

* Updates according to review
+ updated docstrings

* Updated notebook and docstring according to the review

* torch script -> torchscript
parent 4106dbb8
@@ -14,6 +14,24 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is the number of images in the batch. Deterministic or
random transformations applied on a batch of Tensor Images identically transform all the images of the batch.
.. warning::

    Since v0.8.0 all random transformations use the torch default random generator to sample random parameters.
    This is a backward-incompatible change, and users should set the random state as follows:

    .. code:: python

        # Previous versions
        # import random
        # random.seed(12)

        # Now
        import torch
        torch.manual_seed(17)

    Please keep in mind that the same seed for the torch random generator and the Python random generator will not
    produce the same results.
Scriptable transforms
---------------------
@@ -33,6 +51,7 @@ Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
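A minimal sketch of such a custom transformation (the class name and parameters below are illustrative, not part of torchvision):

```python
import torch
import torch.nn as nn


class ClampPixels(nn.Module):
    """Hypothetical custom transform: clamp tensor image values to a range.

    Deriving from nn.Module, with a typed forward, keeps it compatible
    with torch.jit.script.
    """

    def __init__(self, low: float = 0.0, high: float = 1.0):
        super().__init__()
        self.low = low
        self.high = high

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp(x, self.low, self.high)


scripted = torch.jit.script(ClampPixels(0.0, 0.5))
```

The scripted module behaves like the eager one and can be saved with `scripted.save(...)` for deployment.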
Compositions of transforms
--------------------------
# Python examples
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
[Examples of Tensor Images transformations](https://github.com/pytorch/vision/blob/master/examples/python/tensor_transforms.ipynb)
Prior to v0.8.0, transforms in torchvision were PIL-centric and thus had several limitations. Since v0.8.0, the
transform implementations are compatible with both Tensor and PIL inputs, which enables the following new
features:
- transform multi-band torch tensor images (with more than 3-4 channels)
- torchscript transforms together with your model for deployment
- support for GPU acceleration
- batched transformations, e.g. for videos
- read and decode data directly as torch tensors with torchscript support (for PNG and JPEG image formats)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vjAC2mZnb4nz"
},
"source": [
"# Image transformations\n",
"\n",
"This notebook shows the new features of torchvision image transformations.\n",
"\n",
"Prior to v0.8.0, transforms in torchvision were PIL-centric and thus had several limitations. Since v0.8.0, the transform implementations are compatible with both Tensor and PIL inputs, which enables the following new\n",
"features:\n",
"- transform multi-band torch tensor images (with more than 3-4 channels)\n",
"- torchscript transforms together with your model for deployment\n",
"- support for GPU acceleration\n",
"- batched transformations, e.g. for videos\n",
"- read and decode data directly as torch tensors with torchscript support (for PNG and JPEG image formats)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "btaDWPDbgIyW",
"outputId": "8a83d408-f643-42da-d247-faf3a1bd3ae0"
},
"outputs": [],
"source": [
"import torch, torchvision\n",
"torch.__version__, torchvision.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Vj9draNb4oA"
},
"source": [
"## Transforms on CPU/CUDA tensor images\n",
"\n",
"Let's show how to apply transformations to images opened directly as torch tensors.\n",
"Torchvision now provides image reading functions for PNG and JPEG images with torchscript support."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Epp3hCy0b4oD"
},
"outputs": [],
"source": [
"from torchvision.datasets.utils import download_url\n",
"\n",
"download_url(\"https://farm1.static.flickr.com/152/434505223_8d1890e1e2.jpg\", \".\", \"test-image.jpg\")\n",
"download_url(\"https://farm3.static.flickr.com/2142/1896267403_24939864ba.jpg\", \".\", \"test-image2.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-m7lYDPb4oK"
},
"outputs": [],
"source": [
"import matplotlib.pylab as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 303
},
"id": "5bi8Q7L3b4oc",
"outputId": "e5de5c73-e16d-4992-ebee-94c7ddf0bf54"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"tensor_image = read_image(\"test-image.jpg\")\n",
"\n",
"print(\"tensor image info: \", tensor_image.shape, tensor_image.dtype)\n",
"\n",
"plt.imshow(tensor_image.numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_rgb_image(tensor):\n",
" \"\"\"Helper method to get RGB numpy array for plotting\"\"\"\n",
" np_img = tensor.cpu().numpy().transpose((1, 2, 0))\n",
" m1, m2 = np_img.min(axis=(0, 1)), np_img.max(axis=(0, 1))\n",
" return (255.0 * (np_img - m1) / (m2 - m1)).astype(\"uint8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "PgWpjxQ3b4pF",
"outputId": "e9a138e8-b45c-4f75-d849-3b41de0e5472"
},
"outputs": [],
"source": [
"import torchvision.transforms as T\n",
"\n",
"# fix the random seed for reproducible random transform parameters:\n",
"torch.manual_seed(12)\n",
"\n",
"transforms = T.Compose([\n",
" T.RandomCrop(224),\n",
" T.RandomHorizontalFlip(p=0.3),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
"])\n",
"\n",
"out_image = transforms(tensor_image)\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmYQB4cxb4pI"
},
"source": [
"Tensor images can also be placed on the GPU:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 322
},
"id": "S6syYJGEb4pN",
"outputId": "86bddb64-e648-45f2-c216-790d43cfc26d"
},
"outputs": [],
"source": [
"out_image = transforms(tensor_image.to(\"cuda\"))\n",
"print(\"output tensor image info: \", out_image.shape, out_image.dtype, out_image.device)\n",
"\n",
"plt.imshow(to_rgb_image(out_image))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jg9TQd7ajfyn"
},
"source": [
"## Scriptable transforms for easier deployment via torchscript\n",
"\n",
"Next, we show how to combine the input transformations with the model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n",
"\n",
"**Note:** we must use only scriptable transformations, i.e. those derived from `torch.nn.Module`.\n",
"Since v0.8.0, all transformations are scriptable except `Compose`, `RandomApply`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images.\n",
"Transformations like `Compose` are kept for backward compatibility and can easily be replaced by existing torch modules such as `nn.Sequential`.\n",
"\n",
"Let's define a module `Predictor` that transforms the input tensor and applies an ImageNet-pretrained resnet18 model to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NSDOJ3RajfvO"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as T\n",
"from torchvision.io.image import read_image\n",
"from torchvision.models import resnet18\n",
"\n",
"\n",
"class Predictor(nn.Module):\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.resnet18 = resnet18(pretrained=True).eval()\n",
" self.transforms = nn.Sequential(\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" T.ConvertImageDtype(torch.float),\n",
" T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" with torch.no_grad():\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return y_pred.argmax(dim=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZZKDovqej5vA"
},
"source": [
"Now, let's define scripted and non-scripted instances of `Predictor` and apply them to multiple tensor images of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GBBMSo7vjfr0"
},
"outputs": [],
"source": [
"from torchvision.io.image import read_image\n",
"\n",
"predictor = Predictor().to(\"cuda\")\n",
"scripted_predictor = torch.jit.script(predictor).to(\"cuda\")\n",
"\n",
"\n",
"tensor_image1 = read_image(\"test-image.jpg\")\n",
"tensor_image2 = read_image(\"test-image2.jpg\")\n",
"batch = torch.stack([tensor_image1[:, -320:, :], tensor_image2[:, -320:, :]]).to(\"cuda\")\n",
"\n",
"res1 = scripted_predictor(batch)\n",
"res2 = predictor(batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 501
},
"id": "Dmi9r_p-oKsk",
"outputId": "b9c55e7d-5db1-4975-c485-fecc4075bf47"
},
"outputs": [],
"source": [
"import json\n",
"from torchvision.datasets.utils import download_url\n",
"\n",
"\n",
"download_url(\"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\", \".\", \"imagenet_class_index.json\")\n",
"\n",
"\n",
"with open(\"imagenet_class_index.json\", \"r\") as h:\n",
" labels = json.load(h)\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res1):\n",
" plt.subplot(1, 2, i + 1)\n",
"    plt.title(\"Scripted predictor:\\n{label}\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))\n",
"\n",
"\n",
"plt.figure(figsize=(12, 7))\n",
"for i, p in enumerate(res2):\n",
" plt.subplot(1, 2, i + 1)\n",
"    plt.title(\"Original predictor:\\n{label}\".format(label=labels[str(p.item())]))\n",
" plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7IYsjzpFqcK8"
},
"source": [
"We can save the scripted predictor, reload it in Python or C++, and use it for inference:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"id": "0kk9LLw5jfol",
"outputId": "05ea6db7-7fcf-4b74-a763-5f117c14cc00"
},
"outputs": [],
"source": [
"scripted_predictor.save(\"scripted_predictor.pt\")\n",
"\n",
"scripted_predictor = torch.jit.load(\"scripted_predictor.pt\")\n",
"res1 = scripted_predictor(batch)\n",
"\n",
"for i, p in enumerate(res1):\n",
"    print(\"Scripted predictor: {label}\".format(label=labels[str(p.item())]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Data reading and decoding functions also support torchscript and can therefore be part of the model as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AnotherPredictor(Predictor):\n",
"\n",
" def forward(self, path: str) -> int:\n",
" with torch.no_grad():\n",
" x = read_image(path).unsqueeze(0)\n",
" x = self.transforms(x)\n",
" y_pred = self.resnet18(x)\n",
" return int(y_pred.argmax(dim=1).item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-cMwTs3Yjffy"
},
"outputs": [],
"source": [
"scripted_predictor2 = torch.jit.script(AnotherPredictor())\n",
"\n",
"res = scripted_predictor2(\"test-image.jpg\")\n",
"\n",
"print(\"Scripted another predictor: {label}\".format(label=labels[str(res)]))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "torchvision_scriptable_transforms.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@@ -33,7 +33,8 @@ _pil_interpolation_to_str = {
 class Compose:
-    """Composes several transforms together.
+    """Composes several transforms together. This transform does not support torchscript.
+    Please, see the note below.
     Args:
         transforms (list of ``Transform`` objects): list of transforms to compose.
@@ -76,7 +77,7 @@ class Compose:
 class ToTensor:
-    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
     Converts a PIL Image or numpy.ndarray (H x W x C) in the range
     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
@@ -107,7 +108,7 @@ class ToTensor:
 class PILToTensor:
-    """Convert a ``PIL Image`` to a tensor of the same type.
+    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.
     Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
     """
@@ -153,7 +154,7 @@ class ConvertImageDtype(torch.nn.Module):
 class ToPILImage:
-    """Convert a tensor or an ndarray to PIL Image.
+    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.
     Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
     H x W x C to a PIL Image while preserving the value range.
@@ -373,7 +374,7 @@ class Pad(torch.nn.Module):
 class Lambda:
-    """Apply a user-defined lambda as a transform.
+    """Apply a user-defined lambda as a transform. This transform does not support torchscript.
     Args:
         lambd (function): Lambda/function to be used for transform.
@@ -415,7 +416,8 @@ class RandomTransforms:
 class RandomApply(RandomTransforms):
-    """Apply randomly a list of transformations with a given probability
+    """Apply randomly a list of transformations with a given probability.
+    This transform does not support torchscript.
     Args:
         transforms (list or tuple): list of transformations
@@ -444,7 +446,7 @@ class RandomApply(RandomTransforms):
 class RandomOrder(RandomTransforms):
-    """Apply a list of transformations in a random order
+    """Apply a list of transformations in a random order. This transform does not support torchscript.
     """
     def __call__(self, img):
         order = list(range(len(self.transforms)))
@@ -455,7 +457,7 @@ class RandomOrder(RandomTransforms):
 class RandomChoice(RandomTransforms):
-    """Apply single transformation randomly picked from a list
+    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
     """
     def __call__(self, img):
         t = random.choice(self.transforms)