Unverified Commit ed091b14 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge pull request #882 from ROCmSoftwarePlatform/unet

[Example] U-Net Image Segmentation
parents 9054ebbe 42408349
...@@ -10,8 +10,8 @@ This directory contains examples of common use cases for MIGraphX. ...@@ -10,8 +10,8 @@ This directory contains examples of common use cases for MIGraphX.
- [Exporting Frozen Graphs in TF2](./export_frozen_graph_tf2) - [Exporting Frozen Graphs in TF2](./export_frozen_graph_tf2)
- [MIGraphX Docker Container](./migraphx_docker) - [MIGraphX Docker Container](./migraphx_docker)
- [MIGraphX Driver](./migraphx_driver) - [MIGraphX Driver](./migraphx_driver)
- [Python Resnet50 Inference](./python_api_inference) - [Python Resnet50](./python_api_inference)
- [Python BERT SQuAD Inference](./python_bert_squad_example) - [Python BERT-SQuAD](./python_bert_squad_example)
- [Python Super Resolution](./python_super_resolution) - [Python Super Resolution](./python_super_resolution)
- [Python NFNet Inference](./python_nfnet_inference) - [Python NFNet](./python_nfnet_inference)
- [Python U-Net](./python_unet)
...@@ -213,7 +213,7 @@ ...@@ -213,7 +213,7 @@
"# Run the model\n", "# Run the model\n",
"start = time.time()\n", "start = time.time()\n",
"results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n",
"print(f\"Time inference took: {100*(time.time() - start):.2f}ms\")\n", "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n",
"# Extract the index of the top prediction\n", "# Extract the index of the top prediction\n",
"res_npa = np.array(results[0])\n", "res_npa = np.array(results[0])\n",
"print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
...@@ -228,7 +228,7 @@ ...@@ -228,7 +228,7 @@
"# Run the model again, first one would take long\n", "# Run the model again, first one would take long\n",
"start = time.time()\n", "start = time.time()\n",
"results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n",
"print(f\"Time inference took: {100*(time.time() - start):.2f}ms\")\n", "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n",
"# Extract the index of the top prediction\n", "# Extract the index of the top prediction\n",
"res_npa = np.array(results[0])\n", "res_npa = np.array(results[0])\n",
"print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
......
# U-Net Image Segmentation Inference with AMD MIGraphX
This examples provides a simple example for utilizing U-Net ONNX model for image segmentation, using AMD MIGraphX graph optimization engine for fast inference.
## How-to
Please utilize the notebook given `unet_inference.ipynb`.
## Model Details
ONNX model utilized in this example can be found [here](https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx).
\ No newline at end of file
numpy
matplotlib
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"id": "cd7a3990",
"metadata": {},
"source": [
"## Import MIGraphX Python Library"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3930d7b8",
"metadata": {},
"outputs": [],
"source": [
"import migraphx\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "b350c333",
"metadata": {},
"source": [
"## Fetch U-NET ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02a7b7de",
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx"
]
},
{
"cell_type": "markdown",
"id": "a6cfe6e9",
"metadata": {},
"source": [
"## Load ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e05a13dc",
"metadata": {},
"outputs": [],
"source": [
"model = migraphx.parse_onnx(\"unet_13_256.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52c67023",
"metadata": {},
"outputs": [],
"source": [
"model.compile(migraphx.get_target(\"gpu\"))"
]
},
{
"cell_type": "markdown",
"id": "80edb6f1",
"metadata": {},
"source": [
"## Print model parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd5c3269",
"metadata": {},
"outputs": [],
"source": [
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47f956c7",
"metadata": {},
"outputs": [],
"source": [
"def preprocess(pil_img, newW, newH):\n",
" w, h = pil_img.size\n",
" assert newW > 0 and newH > 0, 'Scale is too small'\n",
" pil_img = pil_img.resize((newW, newH))\n",
"\n",
" img_nd = np.array(pil_img)\n",
"\n",
" if len(img_nd.shape) == 2:\n",
" img_nd = np.expand_dims(img_nd, axis=2)\n",
"\n",
" # HWC to CHW\n",
" img_print = pil_img\n",
" img_trans = img_nd.transpose((2, 0, 1))\n",
" if img_trans.max() > 1:\n",
" img_trans = img_trans / 255\n",
" \n",
" img_trans = np.expand_dims(img_trans, 0)\n",
"\n",
" return img_trans, img_print\n",
"\n",
"def plot_img_and_mask(img, mask):\n",
" classes = mask.shape[0] if len(mask.shape) > 3 else 1\n",
" print(classes)\n",
" fig, ax = plt.subplots(1, classes + 1)\n",
" ax[0].set_title('Input image')\n",
" ax[0].imshow(img)\n",
" if classes > 1:\n",
" for i in range(classes):\n",
" ax[i+1].set_title(f'Output mask (class {i+1})')\n",
" ax[i+1].imshow(mask[:, :, i])\n",
" else:\n",
" ax[1].set_title(f'Output mask')\n",
" ax[1].imshow(mask[0,0])\n",
" plt.xticks([]), plt.yticks([])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "389ddc4d",
"metadata": {},
"outputs": [],
"source": [
"img = Image.open(\"./car1.jpeg\")\n",
"img, imPrint = preprocess(img, 256, 256)\n",
"input_im = np.zeros((1,3,256,256),dtype='float32') \n",
"np.lib.stride_tricks.as_strided(input_im, shape=img.shape, strides=input_im.strides)[:] = img #getting correct stride\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"imPrint.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9de6f2a7",
"metadata": {},
"outputs": [],
"source": [
"mask = model.run({'inputs':input_im}) # Your first inference would take longer than the following ones.\n",
"output_mask = np.array(mask[0])\n",
"print(output_mask.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acbd68e3",
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(x):\n",
" return 1 / (1 + np.exp(-x))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58e3062c",
"metadata": {},
"outputs": [],
"source": [
"probs = sigmoid(output_mask)\n",
"full_mask = probs > 0.996\n",
"plot_img_and_mask(imPrint, full_mask)"
]
},
{
"cell_type": "markdown",
"id": "6126df0b",
"metadata": {},
"source": [
"<b>NOTE:</b> The model weights utilized here are trained by using car images with plain backgrounds. The imperfect result on a \"real-world\" image as shown above is expected. To get a better result fine-tuning the model on a dataset of real-world examples is recommended. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment