{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "vjAC2mZnb4nz" }, "source": [ "# Image transformations\n", "\n", "This notebook shows new features of torchvision image transformations.\n", "\n", "Prior to v0.8.0, transforms in torchvision were PIL-centric, which imposed several limitations. Since v0.8.0, transform implementations are compatible with both Tensor and PIL inputs, which enables the following new features:\n", "- transform multi-band torch tensor images (with more than 3-4 channels)\n", "- torchscript transforms together with your model for deployment\n", "- support for GPU acceleration\n", "- batched transformations, such as for videos\n", "- read and decode data directly as a torch tensor with torchscript support (for PNG and JPEG image formats)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "btaDWPDbgIyW", "outputId": "8a83d408-f643-42da-d247-faf3a1bd3ae0" }, "outputs": [], "source": [ "import torch, torchvision\n", "torch.__version__, torchvision.__version__" ] }, { "cell_type": "markdown", "metadata": { "id": "9Vj9draNb4oA" }, "source": [ "## Transforms on CPU/CUDA tensor images\n", "\n", "Let's show how to apply transformations to images opened directly as torch tensors.\n", "torchvision now provides image reading functions for PNG and JPEG images with torchscript support. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Epp3hCy0b4oD" }, "outputs": [], "source": [ "from torchvision.datasets.utils import download_url\n", "\n", "download_url(\"https://farm1.static.flickr.com/152/434505223_8d1890e1e2.jpg\", \".\", \"test-image.jpg\")\n", "download_url(\"https://farm3.static.flickr.com/2142/1896267403_24939864ba.jpg\", \".\", \"test-image2.jpg\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-m7lYDPb4oK" }, "outputs": [], "source": [ "import matplotlib.pylab as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 303 }, "id": "5bi8Q7L3b4oc", "outputId": "e5de5c73-e16d-4992-ebee-94c7ddf0bf54" }, "outputs": [], "source": [ "from torchvision.io.image import read_image\n", "\n", "tensor_image = read_image(\"test-image.jpg\")\n", "\n", "print(\"tensor image info: \", tensor_image.shape, tensor_image.dtype)\n", "\n", "plt.imshow(tensor_image.numpy().transpose((1, 2, 0)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_rgb_image(tensor):\n", " \"\"\"Helper method to get RGB numpy array for plotting\"\"\"\n", " np_img = tensor.cpu().numpy().transpose((1, 2, 0))\n", " m1, m2 = np_img.min(axis=(0, 1)), np_img.max(axis=(0, 1))\n", " return (255.0 * (np_img - m1) / (m2 - m1)).astype(\"uint8\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 322 }, "id": "PgWpjxQ3b4pF", "outputId": "e9a138e8-b45c-4f75-d849-3b41de0e5472" }, "outputs": [], "source": [ "import torchvision.transforms as T\n", "\n", "# to fix random seed is now:\n", "torch.manual_seed(12)\n", "\n", "transforms = T.Compose([\n", " T.RandomCrop(224),\n", " T.RandomHorizontalFlip(p=0.3),\n", " T.ConvertImageDtype(torch.float),\n", " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 
"])\n", "\n", "out_image = transforms(tensor_image)\n", "print(\"output tensor image info: \", out_image.shape, out_image.dtype)\n", "\n", "plt.imshow(to_rgb_image(out_image))" ] }, { "cell_type": "markdown", "metadata": { "id": "LmYQB4cxb4pI" }, "source": [ "Tensor images can be on GPU" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 322 }, "id": "S6syYJGEb4pN", "outputId": "86bddb64-e648-45f2-c216-790d43cfc26d" }, "outputs": [], "source": [ "out_image = transforms(tensor_image.to(\"cuda\"))\n", "print(\"output tensor image info: \", out_image.shape, out_image.dtype, out_image.device)\n", "\n", "plt.imshow(to_rgb_image(out_image))" ] }, { "cell_type": "markdown", "metadata": { "id": "jg9TQd7ajfyn" }, "source": [ "## Scriptable transforms for easier deployment via torchscript\n", "\n", "Next, we show how to combine input transformations and model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n", "\n", "**Note:** we have to use only scriptable transformations that should be derived from `torch.nn.Module`. \n", "Since v0.8.0, all transformations are scriptable except `Compose`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images.\n", "The transformations like `Compose` are kept for backward compatibility and can be easily replaced by existing torch modules, like `nn.Sequential`.\n", "\n", "Let's define a module `Predictor` that transforms input tensor and applies ImageNet pretrained resnet18 model on it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NSDOJ3RajfvO" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torchvision.transforms as T\n", "from torchvision.io.image import read_image\n", "from torchvision.models import resnet18\n", "\n", "\n", "class Predictor(nn.Module):\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.resnet18 = resnet18(pretrained=True).eval()\n", "        self.transforms = nn.Sequential(\n", "            T.Resize([256, ]),  # We use single int value inside a list due to torchscript type restrictions\n", "            T.CenterCrop(224),\n", "            T.ConvertImageDtype(torch.float),\n", "            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", "        )\n", "\n", "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n", "        with torch.no_grad():\n", "            x = self.transforms(x)\n", "            y_pred = self.resnet18(x)\n", "            return y_pred.argmax(dim=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "ZZKDovqej5vA" }, "source": [ "Now, let's define scripted and non-scripted instances of `Predictor` and apply them to multiple tensor images of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GBBMSo7vjfr0" }, "outputs": [], "source": [ "from torchvision.io.image import read_image\n", "\n", "predictor = Predictor().to(\"cuda\")\n", "scripted_predictor = torch.jit.script(predictor).to(\"cuda\")\n", "\n", "\n", "tensor_image1 = read_image(\"test-image.jpg\")\n", "tensor_image2 = read_image(\"test-image2.jpg\")\n", "batch = torch.stack([tensor_image1[:, -320:, :], tensor_image2[:, -320:, :]]).to(\"cuda\")\n", "\n", "res1 = scripted_predictor(batch)\n", "res2 = predictor(batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 501 }, "id": "Dmi9r_p-oKsk", "outputId": "b9c55e7d-5db1-4975-c485-fecc4075bf47" }, "outputs": [], "source": [ "import json\n", "from torchvision.datasets.utils import download_url\n", 
"\n", "\n", "download_url(\"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\", \".\", \"imagenet_class_index.json\")\n", "\n", "\n", "with open(\"imagenet_class_index.json\", \"r\") as h:\n", " labels = json.load(h)\n", "\n", "\n", "plt.figure(figsize=(12, 7))\n", "for i, p in enumerate(res1):\n", " plt.subplot(1, 2, i + 1)\n", " plt.title(\"Scripted predictor:\\n{label})\".format(label=labels[str(p.item())]))\n", " plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))\n", "\n", "\n", "plt.figure(figsize=(12, 7))\n", "for i, p in enumerate(res2):\n", " plt.subplot(1, 2, i + 1)\n", " plt.title(\"Original predictor:\\n{label})\".format(label=labels[str(p.item())]))\n", " plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))" ] }, { "cell_type": "markdown", "metadata": { "id": "7IYsjzpFqcK8" }, "source": [ "We save and reload scripted predictor in Python or C++ and use it for inference:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 52 }, "id": "0kk9LLw5jfol", "outputId": "05ea6db7-7fcf-4b74-a763-5f117c14cc00" }, "outputs": [], "source": [ "scripted_predictor.save(\"scripted_predictor.pt\")\n", "\n", "scripted_predictor = torch.jit.load(\"scripted_predictor.pt\")\n", "res1 = scripted_predictor(batch)\n", "\n", "for i, p in enumerate(res1):\n", " print(\"Scripted predictor: {label})\".format(label=labels[str(p.item())]))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data reading and decoding functions also support torch script and therefore can be part of the model as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AnotherPredictor(Predictor):\n", "\n", " def forward(self, path: str) -> int:\n", " with torch.no_grad():\n", " x = read_image(path).unsqueeze(0)\n", " x = self.transforms(x)\n", " y_pred = self.resnet18(x)\n", " return 
int(y_pred.argmax(dim=1).item())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-cMwTs3Yjffy" }, "outputs": [], "source": [ "scripted_predictor2 = torch.jit.script(AnotherPredictor())\n", "\n", "res = scripted_predictor2(\"test-image.jpg\")\n", "\n", "print(\"Scripted another predictor: {label}\".format(label=labels[str(res)]))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "torchvision_scriptable_transforms.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }