{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1a0c4f2",
   "metadata": {},
   "source": [
    "# U-Net Image Segmentation with MIGraphX\n",
    "\n",
    "This notebook loads a pre-trained U-Net ONNX model, compiles it for the GPU with MIGraphX, runs it on a car image, and plots the thresholded segmentation mask."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd7a3990",
   "metadata": {},
   "source": [
    "## Import MIGraphX Python Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3930d7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import migraphx\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c2e6a41",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8d2f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PATH = 'unet_13_256.onnx'  # file name produced by the download cell below\n",
    "IMAGE_PATH = './car1.jpeg'       # input image; must already exist next to the notebook\n",
    "INPUT_WIDTH = 256                # spatial size the image is resized to\n",
    "INPUT_HEIGHT = 256\n",
    "MASK_THRESHOLD = 0.996           # sigmoid probability above which a pixel is marked as mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b350c333",
   "metadata": {},
   "source": [
    "## Fetch U-NET ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02a7b7de",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -nc https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6cfe6e9",
   "metadata": {},
   "source": [
    "## Load ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e05a13dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = migraphx.parse_onnx(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52c67023",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(migraphx.get_target(\"gpu\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80edb6f1",
   "metadata": {},
   "source": [
    "## Print model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd5c3269",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model.get_parameter_names())\n",
    "print(model.get_parameter_shapes())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7f0b93",
   "metadata": {},
   "source": [
    "## Define pre-processing and plotting helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f956c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(pil_img, newW, newH):\n",
    "    \"\"\"Resize a PIL image and convert it to a batched, channel-first array.\n",
    "\n",
    "    Args:\n",
    "        pil_img: PIL image to prepare.\n",
    "        newW: Target width in pixels; must be positive.\n",
    "        newH: Target height in pixels; must be positive.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(img_trans, img_print)``. ``img_trans`` is a NumPy array of\n",
    "        shape ``(1, C, newH, newW)``, divided by 255 when the resized image\n",
    "        holds values above 1. ``img_print`` is the resized PIL image.\n",
    "    \"\"\"\n",
    "    assert newW > 0 and newH > 0, 'Target width and height must be positive'\n",
    "    img_print = pil_img.resize((newW, newH))\n",
    "\n",
    "    img_nd = np.array(img_print)\n",
    "    if img_nd.ndim == 2:\n",
    "        # Single-channel images come back as (H, W); add the channel axis.\n",
    "        img_nd = np.expand_dims(img_nd, axis=2)\n",
    "\n",
    "    # HWC to CHW\n",
    "    img_trans = img_nd.transpose((2, 0, 1))\n",
    "    if img_trans.max() > 1:\n",
    "        img_trans = img_trans / 255\n",
    "\n",
    "    # Add the leading batch dimension.\n",
    "    img_trans = np.expand_dims(img_trans, 0)\n",
    "\n",
    "    return img_trans, img_print\n",
    "\n",
    "\n",
    "def plot_img_and_mask(img, mask):\n",
    "    \"\"\"Plot the input image next to each class channel of a mask.\n",
    "\n",
    "    Args:\n",
    "        img: Image to show in the first panel (anything ``imshow`` accepts).\n",
    "        mask: Array-like mask laid out as ``(N, C, H, W)``, ``(C, H, W)`` or\n",
    "            ``(H, W)``. For a batched mask only the first sample is plotted.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``mask`` does not have 2, 3 or 4 dimensions.\n",
    "    \"\"\"\n",
    "    mask = np.asarray(mask)\n",
    "    # Normalise every accepted layout to (C, H, W) so channels are indexed uniformly.\n",
    "    if mask.ndim == 4:\n",
    "        mask = mask[0]\n",
    "    elif mask.ndim == 2:\n",
    "        mask = mask[np.newaxis]\n",
    "    elif mask.ndim != 3:\n",
    "        raise ValueError(f'Expected a 2-D, 3-D or 4-D mask, got shape {mask.shape}')\n",
    "\n",
    "    classes = mask.shape[0]\n",
    "    fig, axes = plt.subplots(1, classes + 1)\n",
    "    axes[0].set_title('Input image')\n",
    "    axes[0].imshow(img)\n",
    "    for i in range(classes):\n",
    "        title = f'Output mask (class {i+1})' if classes > 1 else 'Output mask'\n",
    "        axes[i + 1].set_title(title)\n",
    "        axes[i + 1].imshow(mask[i])\n",
    "    # Hide the tick marks on every panel, not only the last one drawn.\n",
    "    for ax in axes:\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84c6e1af",
   "metadata": {},
   "source": [
    "## Load and pre-process the input image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "389ddc4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Force three channels so grayscale or RGBA files still match the 3-channel input.\n",
    "img_raw = Image.open(IMAGE_PATH).convert('RGB')\n",
    "img_array, img_display = preprocess(img_raw, INPUT_WIDTH, INPUT_HEIGHT)\n",
    "\n",
    "# The model input must be float32 with standard C-contiguous strides.\n",
    "input_im = np.ascontiguousarray(img_array, dtype=np.float32)\n",
    "assert input_im.shape == (1, 3, INPUT_HEIGHT, INPUT_WIDTH), input_im.shape\n",
    "print(input_im.strides)\n",
    "print(input_im.shape)\n",
    "img_display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f05a3d68",
   "metadata": {},
   "source": [
    "## Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de6f2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = model.run({'inputs': input_im})  # Your first inference would take longer than the following ones.\n",
    "output_mask = np.array(results[0])\n",
    "output_mask.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6be9471d",
   "metadata": {},
   "source": [
    "## Post-process and plot the mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e3062c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the raw model output into probabilities, then binarise with the configured threshold.\n",
    "probs = torch.sigmoid(torch.from_numpy(output_mask))\n",
    "full_mask = probs > MASK_THRESHOLD\n",
    "plot_img_and_mask(img_display, full_mask.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6126df0b",
   "metadata": {},
   "source": [
    "NOTE: The model weights used here were trained on car images with plain backgrounds, so the imperfect result on a \"real-world\" image shown above is expected. To get a better result, fine-tuning the model on a dataset of real-world examples is recommended."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}