{ "cells": [ { "cell_type": "code", "metadata": {}, "source": [ "# The MIT License (MIT)", "#", "# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.", "#", "# Permission is hereby granted, free of charge, to any person obtaining a copy", "# of this software and associated documentation files (the 'Software'), to deal", "# in the Software without restriction, including without limitation the rights", "# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell", "# copies of the Software, and to permit persons to whom the Software is", "# furnished to do so, subject to the following conditions:", "#", "# The above copyright notice and this permission notice shall be included in", "# all copies or substantial portions of the Software.", "#", "# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE", "# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,", "# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN", "# THE SOFTWARE." ] }, { "cell_type": "markdown", "id": "fee8cfa5", "metadata": {}, "source": [ "# 3D-UNet Example with MIGraphX\n", "References:
\n", "https://github.com/naomifridman/Unet_Brain_tumor_segmentation" ] }, { "cell_type": "code", "execution_count": null, "id": "09ceec31", "metadata": {}, "outputs": [], "source": [ "!pip install SimpleITK matplotlib scikit-image" ] }, { "cell_type": "code", "execution_count": null, "id": "bb22bcc4", "metadata": {}, "outputs": [], "source": [ "import migraphx\n", "from PIL import Image\n", "import numpy as np\n", "import os\n", "import SimpleITK as sitk" ] }, { "cell_type": "markdown", "id": "cb973c63", "metadata": {}, "source": [ "## Fetch U-NET ONNX Model" ] }, { "cell_type": "code", "execution_count": null, "id": "1928662c", "metadata": {}, "outputs": [], "source": [ "!wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx" ] }, { "cell_type": "markdown", "id": "1a64a616", "metadata": {}, "source": [ "## Load ONNX Model" ] }, { "cell_type": "code", "execution_count": null, "id": "53928a98", "metadata": {}, "outputs": [], "source": [ "model = migraphx.parse_onnx(\"224_224_160.onnx\")" ] }, { "cell_type": "code", "execution_count": null, "id": "27e8587f", "metadata": {}, "outputs": [], "source": [ "model.compile(migraphx.get_target(\"gpu\"))" ] }, { "cell_type": "markdown", "id": "2f6014a4", "metadata": {}, "source": [ "## Print model parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "9e73728c", "metadata": {}, "outputs": [], "source": [ "print(model.get_parameter_names())\n", "print(model.get_parameter_shapes())\n", "print(model.get_output_shapes())" ] }, { "cell_type": "code", "execution_count": null, "id": "a4cac52e", "metadata": {}, "outputs": [], "source": [ "img_type=['FLAIR', 'T1','T1CE', 'T2']\n", "label_type_shrt = ['background', 'necrotic',\n", " 'edema', 'enhancing']\n", "label_type = ['background', 'necrotic and non-enhancing tumor', 'edema', 'enhancing tumor']" ] }, { "cell_type": "code", "execution_count": null, "id": "b65f9297", "metadata": {}, "outputs": [], "source": [ "red_multiplier = [1, 0.2, 0.2]\n", "green_multiplier = [0.35,0.75,0.25]\n", "blue_multiplier = [0,0.5,1.]#[0,0.25,0.9]\n", "yellow_multiplier = [1,1,0.25]\n", "brown_miltiplier = [40./255, 26./255, 13./255]\n", "my_colors=[blue_multiplier, yellow_multiplier, brown_miltiplier]" ] }, { "cell_type": "code", "execution_count": null, "id": "0e175ac5", "metadata": {}, "outputs": [], "source": [ "from importlib import reload # Python 3.4+ only." ] }, { "cell_type": "code", "execution_count": null, "id": "530e4f97", "metadata": {}, "outputs": [], "source": [ "import visualization_utils as vu\n", "from visualization_utils import show_label_on_image4\n", "reload(vu)" ] }, { "cell_type": "code", "execution_count": null, "id": "865c46a2", "metadata": {}, "outputs": [], "source": [ "def show_img_label(img, lbl, modality = 0):\n", " \n", " if (len(lbl.shape)> 2):\n", " lbl[0,0,3]=1 # for uniqe colors in plot\n", " lbl = lbl_from_cat(lbl)\n", " vu.show_n_images([img[:,:,modality],lbl, show_label_on_image4(img[:,:,modality],lbl)],\n", " titles = [img_type[modality], 'Label', 'Label on '+ img_type[modality]]);\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1e926482", "metadata": {}, "outputs": [], "source": [ "def read_img_sitk(img):\n", " inputImage = sitk.ReadImage( img )\n", " inputImage = sitk.Cast( inputImage, sitk.sitkFloat32 )\n", " image = sitk.GetArrayFromImage(inputImage)\n", " return image" ] }, { "cell_type": "code", "execution_count": null, "id": "0b620138", "metadata": {}, "outputs": [], "source": [ "# ima files are of the form\n", "# BraTS19_TCIA04_192_1_flair.nii.gz \n", "# BraTS19_TCIA04_192_1_t1.nii.gz \n", "# BraTS19_TCIA04_192_1_t2.nii.gz\n", "# BraTS19_TCIA04_192_1_seg.nii.gz \n", "# BraTS19_TCIA04_192_1_t1ce.nii.gz\n", "\n", "def read_image_into_numpy(dirpath):\n", " \n", " img_id = os.path.basename(dirpath)\n", " np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n", " \n", " ## Flair\n", " flair_img = os.path.join(dirpath, img_id+'_flair.nii.gz')\n", " if (not os.path.isfile(flair_img)):\n", " print(flair_img,' not found aborting')\n", " return None\n", " np_image[0] = read_img_sitk(flair_img)\n", " \n", " ## T1\n", " t1_nb4_img = os.path.join(dirpath, img_id+'_t1_nb4.nii.gz')\n", " if (not os.path.isfile(t1_nb4_img)):\n", " #print(t1_nb4_img,' not found')\n", " t1_img = os.path.join(dirpath, img_id+'_t1.nii.gz')\n", " if (not os.path.isfile(t1_img)):\n", " print(t1_img,' not found aborting')\n", " return None\n", " np_image[1] = read_img_sitk(t1_img)\n", " else:\n", " np_image[1] = read_img_sitk(t1_nb4_img) \n", " \n", " ## T1CE\n", " t1ce_nb4_img = os.path.join(dirpath, img_id+'_t1ce_nb4.nii.gz')\n", " if (not os.path.isfile(t1ce_nb4_img)):\n", " #print(t1ce_nb4_img,' not found')\n", " t1ce_img = os.path.join(dirpath, img_id+'_t1ce.nii.gz')\n", " if (not os.path.isfile(t1ce_img)):\n", " print(t1ce_img,' not found aborting')\n", " return None\n", " np_image[2] = read_img_sitk(t1ce_img)\n", " else:\n", " np_image[2] = read_img_sitk(t1ce_nb4_img) \n", " \n", " \n", " ## T2\n", " t2_img = os.path.join(dirpath, img_id+'_t2.nii.gz')\n", " if (not os.path.isfile(t2_img)):\n", " print(t2_img,' not found aborting')\n", " return None\n", " np_image[3] = read_img_sitk(t2_img)\n", "\n", " return np_image" ] }, { "cell_type": "code", "execution_count": null, "id": "2fb66f17", "metadata": {}, "outputs": [], "source": [ "def read_label_into_numpy(dirpath):\n", " \n", " img_id = os.path.basename(dirpath)\n", " np_image=np.zeros((160, 224, 224), dtype=np.int)\n", " \n", " ## label\n", " label_img = os.path.join(dirpath, img_id+'_seg.nii.gz')\n", " if (not os.path.isfile(label_img)):\n", " print(label_img,' not found aborting')\n", " return None\n", " np_image = read_img_sitk(label_img).astype(int)\n", "\n", " return np_image" ] }, { "cell_type": "code", "execution_count": null, "id": "558d47b9", "metadata": {}, "outputs": [], "source": [ "def bbox2_3D(img):\n", "\n", " r = np.any(img, axis=(1, 2))\n", " c = np.any(img, axis=(0, 2))\n", " z = np.any(img, axis=(0, 1))\n", "\n", " rmin, rmax = np.where(r)[0][[0, -1]]\n", " cmin, cmax = np.where(c)[0][[0, -1]]\n", " zmin, zmax = np.where(z)[0][[0, -1]]\n", "\n", " return [rmin, rmax, cmin, cmax, zmin, zmax]" ] }, { "cell_type": "code", "execution_count": null, "id": "1405e186", "metadata": {}, "outputs": [], "source": [ "def lbl_from_cat(cat_lbl):\n", " \n", " lbl=0\n", " if (len(cat_lbl.shape)==3):\n", " for i in range(1,4):\n", " lbl = lbl + cat_lbl[:,:,i]*i\n", " elif (len(cat_lbl.shape)==4):\n", " for i in range(1,4):\n", " lbl = lbl + cat_lbl[:,:,:,i]*i\n", " else:\n", " print('Error in lbl_from_cat', cat_lbl.shape)\n", " return None\n", " return lbl" ] }, { "cell_type": "code", "execution_count": null, "id": "24eb472f", "metadata": {}, "outputs": [], "source": [ "def show_label(lbl):\n", " vu.show_n_images([lbl[:,:,k] for k in range(4)]+[lbl_from_cat(lbl)],\n", " titles = label_type_shrt + ['Label'])\n", "\n", "def show_pred_im_label(im, lb, pred):\n", " \n", " vu.show_n_images([im[:,:,1], lb[:,:], \n", " show_label_on_image4(im[:,:,1], lb[:,:]),\n", " show_label_on_image4(im[:,:,1], pred[:,:])],\n", " titles=['Flair', 'Label', 'Label on T1', 'Prediction on Flair'])\n", "\n", "def show_pred_im(im, pred):\n", " \n", " vu.show_n_images([im[:,:,1], \n", " im[:,:,0],pred,\n", " show_label_on_image4(im[:,:,1], pred[:,:])],\n", " titles=['Flair','T1', 'Pred', 'Prediction on Flair'])" ] }, { "cell_type": "markdown", "id": "d15f788b", "metadata": {}, "source": [ "Multiple image inputs:\n", "- Native (T1)\n", "- Post-contrast T1-weighted (T1Gd)\n", "- T2-weighted (T2)\n", "- T2 Fluid Attenuated Inversion Recovery (T2-FLAIR)" ] }, { "cell_type": "code", "execution_count": null, "id": "7a7aad87", "metadata": {}, "outputs": [], "source": [ "# Resize input images\n", "from scipy.ndimage import zoom\n", "\n", "def resize(img, shape, mode='constant', orig_shape=(155, 240, 240)):\n", " \"\"\"\n", " Wrapper for scipy.ndimage.zoom suited for MRI images.\n", " \"\"\"\n", " assert len(shape) == 3, \"Can not have more than 3 dimensions\"\n", " factors = (\n", " shape[0]/orig_shape[0],\n", " shape[1]/orig_shape[1], \n", " shape[2]/orig_shape[2]\n", " )\n", " \n", " # Resize to the given shape\n", " return zoom(img, factors, mode=mode)\n", "\n", "def preprocess_label(img, out_shape=None, mode='nearest'):\n", " \"\"\"\n", " Separates out the 3 labels from the segmentation provided, namely:\n", " GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))\n", " and the necrotic and non-enhancing tumor core (NCR/NET — label 1)\n", " \"\"\"\n", " ncr = img == 1 # Necrotic and Non-Enhancing Tumor (NCR/NET)\n", " \n", " ed = img == 2 # Peritumoral Edema (ED)\n", " et = img == 4 # GD-enhancing Tumor (ET)\n", " \n", " if out_shape is not None:\n", " ncr = resize(ncr, out_shape, mode=mode)\n", " ed = resize(ed, out_shape, mode=mode)\n", " et = resize(et, out_shape, mode=mode)\n", " return np.array([ncr, ed, et], dtype=np.uint8)\n", "\n", "hgg_path = \"/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG\"\n", "np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n", "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_flair.nii.gz'%hgg_path)\n", "tmp = resize(tmp, [160,224,224])\n", "mean = tmp.mean()\n", "std = tmp.std()\n", "np_image[0] = (tmp - mean) / std\n", "\n", "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1.nii.gz'%hgg_path)\n", "tmp = resize(tmp, [160,224,224])\n", "mean = tmp.mean()\n", "std = tmp.std()\n", "np_image[1] = (tmp - mean) / std\n", "\n", "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1ce.nii.gz'%hgg_path)\n", "tmp = resize(tmp, [160,224,224])\n", "mean = tmp.mean()\n", "std = tmp.std()\n", "np_image[2] = (tmp - mean) / std\n", "\n", "tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t2.nii.gz'%hgg_path)\n", "tmp = resize(tmp, [160,224,224])\n", "mean = tmp.mean()\n", "std = tmp.std()\n", "np_image[3] = (tmp - mean) / std\n", "\n", "print(np_image.shape)\n", "np_image_tmp = np_image.copy()" ] }, { "cell_type": "code", "execution_count": null, "id": "d7e5b3c6", "metadata": {}, "outputs": [], "source": [ "vu.show_n_images(np_image[:,100,:,:], titles=img_type)" ] }, { "cell_type": "code", "execution_count": null, "id": "19117da5", "metadata": {}, "outputs": [], "source": [ "np_lbl=np.zeros((160, 224, 224), dtype=np.int)\n", "tmp = read_img_sitk('/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_seg.nii.gz').astype(int)\n", "tmp = resize(tmp, [160,224,224])\n", "print(tmp.shape)\n", "np_lbl = tmp.astype(int)\n", "print(np_lbl.shape)\n", "\n", "print(np_image.shape)\n", "\n", "img1 = vu.show_label_on_image4(np_image[1,100,:,:], np_lbl[100])\n", "img2 = vu.show_label_on_image(np_image[1,100,:,:], np_lbl[100])\n", "vu.show_n_images([img1,img2,np_image[0,100]])" ] }, { "cell_type": "code", "execution_count": null, "id": "facdea15", "metadata": {}, "outputs": [], "source": [ "def get_pred(img, threshold=0.5):\n", " out_img=img.copy()\n", " out_img=np.where(out_img>threshold, 1,0)\n", " return out_img\n", "\n", "def prediction_from_probabily_3D(img):\n", " \n", " int_image = get_pred(img)\n", " return lbl_from_cat(int_image)\n", "\n", "def get_prediction_for_batch(pred_batch, threshold=0.5):\n", " \n", " out_batch = np.zeros((pred_batch.shape[0], 224, 224),dtype=np.int)\n", " \n", " for j in range(pred_batch.shape[0]):\n", " pred = get_prediction(pred_batch[j])\n", " if (pred.sum()>0):\n", " print(j, np.unique(pred , return_counts=True))\n", " out_batch[j] = lbl_from_cat(get_prediction(pred_batch[j]))\n", " return out_batch\n", "\n", "def get_label_from_pred_batch(labels_batch):\n", " \n", " batch = np.zeros((labels_batch.shape[0], 224, 224), np.uint8)\n", " \n", " for j in range(labels_batch.shape[0]):\n", " batch[j]=get_pred(labels_batch[j,:,:,0])+\\\n", " get_pred(labels_batch[j,:,:,1])*2+\\\n", " get_pred(labels_batch[j,:,:,2])*4\n", "\n", " return batch\n", "\n", "def predict_3D_img_prob(np_file):\n", " \n", " np_img = np.load(np_file)\n", " for_pred_img = np.zeros((160, 224, 224, 4), np.float32)\n", "\n", " # Normalize image\n", " for_pred_img = normalize_3D_image(np_img)\n", "\n", " mdl_pred_img = model.predict(for_pred_img)\n", "\n", " #pred_label = prediction_from_probabily_3D(mdl_pred_img)\n", "\n", " return mdl_pred_img\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7f7fe7ee", "metadata": {}, "outputs": [], "source": [ "#Remember the MIGraphX model inputs\n", "print(model.get_parameter_names())\n", "print(model.get_parameter_shapes())\n", "\n", "np_image = np_image.transpose((0,2,3,1))\n", "\n", "print(np_image.shape)\n", "print(np_image.strides)" ] }, { "cell_type": "code", "execution_count": null, "id": "dfc47b53", "metadata": {}, "outputs": [], "source": [ "def normalize_3D_image(img):\n", " for z in range(img.shape[0]):\n", " for k in range(4):\n", " if (img[z,:,:,k].max()>0):\n", " img[z,:,:,k] /= img[z,:,:,k].max()\n", " return img" ] }, { "cell_type": "code", "execution_count": null, "id": "f990cb50", "metadata": {}, "outputs": [], "source": [ "print(np_image_tmp.shape)\n", "np_image_tmp = np_image_tmp.transpose((1,2,3,0))\n", "print(np_image_tmp.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "24c3736d", "metadata": {}, "outputs": [], "source": [ "np_image = np.expand_dims(np_image, 0)\n", "print(np_image.shape)\n", "print(np_image.strides)" ] }, { "cell_type": "code", "execution_count": null, "id": "1aac6285", "metadata": {}, "outputs": [], "source": [ "input_im = np.zeros((1,4,224,224,160),dtype='float32')\n", "np.lib.stride_tricks.as_strided(input_im, shape=np_image.shape, strides=input_im.strides)[:] = np_image #getting correct stride\n", "print(input_im.strides)\n", "print(input_im.shape)\n", "\n", "#input_im = normalize_3D_image(input_im)\n", "\n", "print(input_im.strides)\n", "print(input_im.shape)\n", "\n", "result = model.run({\n", " \"input\": input_im\n", " })" ] }, { "cell_type": "code", "execution_count": null, "id": "5848b63d", "metadata": {}, "outputs": [], "source": [ "output = np.array(result[0])\n", "print(output.shape)\n", "output = output[0]\n", "print(output.shape)\n", "output = output.transpose((3,1,2,0))\n", "print(output.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "ab77f7e9", "metadata": {}, "outputs": [], "source": [ "out = prediction_from_probabily_3D(output)\n", "print(np_image_tmp.shape)\n", "print(np_lbl.shape)\n", "print(out.shape)\n", "print(np.unique(out))\n", "ind=[100]\n", "for i in ind:\n", " show_label(output[i])\n", " show_label(get_pred(output[i]))\n", " show_pred_im_label(np_image_tmp[i], np_lbl[i], out[i])" ] }, { "cell_type": "markdown", "id": "d2862d81", "metadata": {}, "source": [ "The possible prediction discrepancy is due to the not-perfect resizing 3D input image, as BRATS dataset has 3D images of size 160x240x240, meanwhile the ONNX model utilized here requires 155x224x224. This example is representative for how to utilize MIGraphX for such an application. All data processing should follow and match the model requirements otherwise. " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }