Unverified Commit d43cf0a3 authored by Cagri Eryilmaz's avatar Cagri Eryilmaz Committed by GitHub
Browse files

Changes to Jupyter Notebook of ResNet-50 Python Inference Example (#759)



* adding changes for resnet50 inference: opencv version problem wit qt + headless server support

* added histogram output instead of text
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 0fa539da
......@@ -12,11 +12,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install --upgrade pip\n",
"!pip install opencv-python==4.1.2.30\n",
"!pip install matplotlib\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt \n",
"import cv2\n",
"import json\n",
"import time\n",
......@@ -30,16 +34,22 @@
"metadata": {},
"source": [
"### Importing MIGraphX Library\n",
"Sometimes the PYTHONPATH variable is not set during installation of MIGraphX. If your receive a \"Module Not Found\" error when trying to `import migraphx` in your own application, try running:\n",
"Sometimes the PYTHONPATH variable is not set during installation of MIGraphX. \n",
"If your receive a \"Module Not Found\" error when trying to `import migraphx` in your own application, try running:\n",
"```\n",
"$ export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH\n",
"```\n",
"For this example, the library will be added to the kernel's sys.path."
"For this example, the library will be added to the kernel's sys.path.\n",
"\n",
"If you receive \"cannot open shared object file: No such file or directory\" , please make sure `/opt/rocm/lib` is included in $LD_LIBRARY_PATH\n",
"```\n",
" cannot open shared object file: No such file or directory\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -50,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -68,28 +78,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"········\n",
"[youtube] TkqYmvH_XVs: Downloading webpage\n",
"[youtube] TkqYmvH_XVs: Downloading MPD manifest\n",
"[dashsegments] Total fragments: 34\n",
"[download] Destination: sample_vid-TkqYmvH_XVs.f137.mp4\n",
"\u001b[K[download] 100% of 70.35MiB in 00:06.31MiB/s ETA 00:000:11\n",
"[dashsegments] Total fragments: 18\n",
"[download] Destination: sample_vid-TkqYmvH_XVs.f140.m4a\n",
"\u001b[K[download] 100% of 2.58MiB in 00:01.99MiB/s ETA 00:000102\n",
"[ffmpeg] Merging formats into \"sample_vid-TkqYmvH_XVs.mp4\"\n",
"Deleting original file sample_vid-TkqYmvH_XVs.f137.mp4 (pass -k to keep)\n",
"Deleting original file sample_vid-TkqYmvH_XVs.f140.m4a (pass -k to keep)\n"
]
}
],
"outputs": [],
"source": [
"if not path.exists(\"./sample_vid.mp4\"):\n",
" import getpass\n",
......@@ -110,36 +101,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-01-13 20:24:16-- https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.onnx?raw=true\n",
"Resolving github.com (github.com)... 140.82.112.3\n",
"Connecting to github.com (github.com)|140.82.112.3|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.onnx [following]\n",
"--2021-01-13 20:24:16-- https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.onnx\n",
"Reusing existing connection to github.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://media.githubusercontent.com/media/onnx/models/master/vision/classification/resnet/model/resnet50-v2-7.onnx [following]\n",
"--2021-01-13 20:24:16-- https://media.githubusercontent.com/media/onnx/models/master/vision/classification/resnet/model/resnet50-v2-7.onnx\n",
"Resolving media.githubusercontent.com (media.githubusercontent.com)... 151.101.48.133\n",
"Connecting to media.githubusercontent.com (media.githubusercontent.com)|151.101.48.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 102442450 (98M) [application/octet-stream]\n",
"Saving to: ‘resnet50-v2-7.onnx?raw=true’\n",
"\n",
"resnet50-v2-7.onnx? 100%[===================>] 97.70M 88.2MB/s in 1.1s \n",
"\n",
"2021-01-13 20:24:19 (88.2 MB/s) - ‘resnet50-v2-7.onnx?raw=true’ saved [102442450/102442450]\n",
"\n"
]
}
],
"outputs": [],
"source": [
"if not path.exists(\"./resnet50.onnx\"):\n",
" !wget https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.onnx?raw=true\n",
......@@ -155,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -174,7 +138,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -204,7 +168,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -234,7 +198,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -258,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -284,12 +248,12 @@
"source": [
"### Inference Loop over Full Video\n",
"\n",
"Now everything is in place so that we can run inference on each frame of the input video. The video will be played and the predicted label will be displayed on top of each frame."
"Now everything is in place so that we can run inference on each frame of the input video. The video will be played and the predicted label will be displayed on top of each frame. If you are working on headless server, please execute the following cell."
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -330,12 +294,34 @@
"cv2.destroyAllWindows()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If script is run on a headless server where .imshow() experiences problems, the following cell for histogram can be run to verify functionalty:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"output_labels = []\n",
"while (cap.isOpened()):\n",
" start = time.perf_counter()\n",
" ret, frame = cap.read()\n",
" if not ret: break\n",
" \n",
" top_prediction = predict_class(frame)\n",
" output_labels.append(labels[top_prediction])\n",
"\n",
"cap.release()\n",
"output_labels = np.array(output_labels)\n",
"plt.hist(output_labels) \n",
"plt.xticks(rotation = 90)\n",
"plt.show()"
]
}
],
"metadata": {
......@@ -354,7 +340,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
"version": "3.6.9"
}
},
"nbformat": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment