{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fee8cfa5",
   "metadata": {},
   "source": [
    "# 3D-UNet Example with MIGraphX\n",
    "\n",
    "This notebook runs a 3D-UNet brain-tumor segmentation ONNX model on one BraTS 2019 MRI case with MIGraphX and visualizes the prediction alongside the ground-truth segmentation.\n",
    "\n",
    "References:\n",
    "- https://github.com/naomifridman/Unet_Brain_tumor_segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ceec31",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install SimpleITK matplotlib scikit-image scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb22bcc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from importlib import reload  # Python 3.4+ only.\n",
    "\n",
    "import migraphx\n",
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "from PIL import Image\n",
    "from scipy.ndimage import zoom"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb973c63",
   "metadata": {},
   "source": [
    "## Fetch U-NET ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1928662c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a64a616",
   "metadata": {},
   "source": [
    "## Load ONNX Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53928a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = migraphx.parse_onnx(\"224_224_160.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e8587f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(migraphx.get_target(\"gpu\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f6014a4",
   "metadata": {},
   "source": [
    "## Print model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e73728c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model.get_parameter_names())\n",
    "print(model.get_parameter_shapes())\n",
    "print(model.get_output_shapes())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1f3c9d2",
   "metadata": {},
   "source": [
    "## Visualization helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4cac52e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channel order used for every image array in this notebook.\n",
    "img_type = ['FLAIR', 'T1', 'T1CE', 'T2']\n",
    "label_type_shrt = ['background', 'necrotic',\n",
    "                   'edema', 'enhancing']\n",
    "label_type = ['background', 'necrotic and non-enhancing tumor', 'edema', 'enhancing tumor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b65f9297",
   "metadata": {},
   "outputs": [],
   "source": [
    "red_multiplier = [1, 0.2, 0.2]\n",
    "green_multiplier = [0.35, 0.75, 0.25]\n",
    "blue_multiplier = [0, 0.5, 1.]\n",
    "yellow_multiplier = [1, 1, 0.25]\n",
    "brown_multiplier = [40./255, 26./255, 13./255]\n",
    "my_colors = [blue_multiplier, yellow_multiplier, brown_multiplier]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "530e4f97",
   "metadata": {},
   "outputs": [],
   "source": [
    "import visualization_utils as vu\n",
    "\n",
    "# Reload first so that the name imported below is bound to the reloaded module's function.\n",
    "reload(vu)\n",
    "from visualization_utils import show_label_on_image4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "865c46a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_img_label(img, lbl, modality=0):\n",
    "    \"\"\"Show one modality of `img`, its label map, and the label drawn on the image.\n",
    "\n",
    "    img: array indexed as img[:, :, modality].\n",
    "    lbl: either a 2D label map or a categorical array that `lbl_from_cat` can collapse.\n",
    "    modality: index into `img_type` selecting the channel to display.\n",
    "    \"\"\"\n",
    "    if len(lbl.shape) > 2:\n",
    "        lbl = lbl.copy()  # work on a copy so the caller's array is not modified\n",
    "        lbl[0, 0, 3] = 1  # for unique colors in plot\n",
    "        lbl = lbl_from_cat(lbl)\n",
    "    vu.show_n_images(\n",
    "        [img[:, :, modality], lbl, show_label_on_image4(img[:, :, modality], lbl)],\n",
    "        titles=[img_type[modality], 'Label', 'Label on ' + img_type[modality]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e2d4a6",
   "metadata": {},
   "source": [
    "## Data loading helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e926482",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_img_sitk(img):\n",
    "    \"\"\"Read the image file `img` with SimpleITK and return it as a float32 numpy array.\"\"\"\n",
    "    inputImage = sitk.ReadImage(img)\n",
    "    inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)\n",
    "    return sitk.GetArrayFromImage(inputImage)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b620138",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ima files are of the form\n",
    "# BraTS19_TCIA04_192_1_flair.nii.gz\n",
    "# BraTS19_TCIA04_192_1_t1.nii.gz\n",
    "# BraTS19_TCIA04_192_1_t2.nii.gz\n",
    "# BraTS19_TCIA04_192_1_seg.nii.gz\n",
    "# BraTS19_TCIA04_192_1_t1ce.nii.gz\n",
    "\n",
    "def read_image_into_numpy(dirpath):\n",
    "    \"\"\"Load the four modalities of the case directory `dirpath` into one array.\n",
    "\n",
    "    Returns a float32 array of shape (4, 160, 224, 224) in the channel order\n",
    "    FLAIR, T1, T1CE, T2, or None (after printing the missing path) when a\n",
    "    modality file is not found. For T1 and T1CE the '_nb4' variant is\n",
    "    preferred when it exists.\n",
    "\n",
    "    NOTE(review): the files are assigned without resizing, so they presumably\n",
    "    must already have shape (160, 224, 224) - TODO confirm.\n",
    "    \"\"\"\n",
    "    img_id = os.path.basename(dirpath)\n",
    "    np_image = np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
    "\n",
    "    # One entry per channel: candidate file suffixes in order of preference.\n",
    "    suffixes_per_channel = [\n",
    "        ['_flair.nii.gz'],\n",
    "        ['_t1_nb4.nii.gz', '_t1.nii.gz'],\n",
    "        ['_t1ce_nb4.nii.gz', '_t1ce.nii.gz'],\n",
    "        ['_t2.nii.gz'],\n",
    "    ]\n",
    "    for channel, suffixes in enumerate(suffixes_per_channel):\n",
    "        candidates = [os.path.join(dirpath, img_id + suffix) for suffix in suffixes]\n",
    "        found = next((path for path in candidates if os.path.isfile(path)), None)\n",
    "        if found is None:\n",
    "            # Report the last (plain) candidate, as the fallback of last resort.\n",
    "            print(candidates[-1], ' not found aborting')\n",
    "            return None\n",
    "        np_image[channel] = read_img_sitk(found)\n",
    "\n",
    "    return np_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fb66f17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_label_into_numpy(dirpath):\n",
    "    \"\"\"Load the '_seg.nii.gz' label volume of the case directory `dirpath` as an int array.\n",
    "\n",
    "    Returns None (after printing the missing path) when the file is not found.\n",
    "    \"\"\"\n",
    "    img_id = os.path.basename(dirpath)\n",
    "\n",
    "    label_img = os.path.join(dirpath, img_id + '_seg.nii.gz')\n",
    "    if not os.path.isfile(label_img):\n",
    "        print(label_img, ' not found aborting')\n",
    "        return None\n",
    "\n",
    "    # `int` replaces the `np.int` alias, which was removed in NumPy 1.24.\n",
    "    return read_img_sitk(label_img).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "558d47b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bbox2_3D(img):\n",
    "    \"\"\"Return [rmin, rmax, cmin, cmax, zmin, zmax], the inclusive index bounds\n",
    "    of the non-zero region of the 3D array `img` along axes 0, 1 and 2.\n",
    "\n",
    "    Raises IndexError when `img` has no non-zero element.\n",
    "    \"\"\"\n",
    "    r = np.any(img, axis=(1, 2))\n",
    "    c = np.any(img, axis=(0, 2))\n",
    "    z = np.any(img, axis=(0, 1))\n",
    "\n",
    "    rmin, rmax = np.where(r)[0][[0, -1]]\n",
    "    cmin, cmax = np.where(c)[0][[0, -1]]\n",
    "    zmin, zmax = np.where(z)[0][[0, -1]]\n",
    "\n",
    "    return [rmin, rmax, cmin, cmax, zmin, zmax]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1405e186",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lbl_from_cat(cat_lbl):\n",
    "    \"\"\"Collapse a categorical array into a label map.\n",
    "\n",
    "    `cat_lbl` must be 3D or 4D with the categories on the last axis; the\n",
    "    result is the sum over i = 1..3 of cat_lbl[..., i] * i. For any other\n",
    "    number of dimensions an error is printed and None is returned.\n",
    "    \"\"\"\n",
    "    if len(cat_lbl.shape) not in (3, 4):\n",
    "        print('Error in lbl_from_cat', cat_lbl.shape)\n",
    "        return None\n",
    "\n",
    "    lbl = 0\n",
    "    for i in range(1, 4):\n",
    "        lbl = lbl + cat_lbl[..., i] * i\n",
    "    return lbl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24eb472f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_label(lbl):\n",
    "    \"\"\"Show the four category planes of `lbl` and the collapsed label map.\"\"\"\n",
    "    vu.show_n_images([lbl[:, :, k] for k in range(4)] + [lbl_from_cat(lbl)],\n",
    "                     titles=label_type_shrt + ['Label'])\n",
    "\n",
    "\n",
    "def show_pred_im_label(im, lb, pred):\n",
    "    \"\"\"Show channel 1 of `im`, the label `lb`, and label / prediction drawn on that channel.\n",
    "\n",
    "    Channel 1 is T1 in this notebook's channel order (see `img_type`), so the\n",
    "    titles say T1.\n",
    "    \"\"\"\n",
    "    vu.show_n_images([im[:, :, 1], lb[:, :],\n",
    "                      show_label_on_image4(im[:, :, 1], lb[:, :]),\n",
    "                      show_label_on_image4(im[:, :, 1], pred[:, :])],\n",
    "                     titles=['T1', 'Label', 'Label on T1', 'Prediction on T1'])\n",
    "\n",
    "\n",
    "def show_pred_im(im, pred):\n",
    "    \"\"\"Show channels 1 (T1) and 0 (FLAIR) of `im`, `pred`, and `pred` drawn on channel 1.\"\"\"\n",
    "    vu.show_n_images([im[:, :, 1],\n",
    "                      im[:, :, 0], pred,\n",
    "                      show_label_on_image4(im[:, :, 1], pred[:, :])],\n",
    "                     titles=['T1', 'Flair', 'Pred', 'Prediction on T1'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d15f788b",
   "metadata": {},
   "source": [
    "## Load and preprocess a BraTS volume\n",
    "\n",
    "Multiple image inputs:\n",
    "- Native (T1)\n",
    "- Post-contrast T1-weighted (T1Gd)\n",
    "- T2-weighted (T2)\n",
    "- T2 Fluid Attenuated Inversion Recovery (T2-FLAIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a7aad87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize input images\n",
    "\n",
    "def resize(img, shape, mode='constant', orig_shape=(155, 240, 240), order=3):\n",
    "    \"\"\"\n",
    "    Wrapper for scipy.ndimage.zoom suited for MRI images.\n",
    "\n",
    "    img: 3D array; the zoom factors are computed from `orig_shape`, not from img.shape.\n",
    "    shape: target shape (3 elements).\n",
    "    mode: boundary mode passed to `zoom`.\n",
    "    orig_shape: shape the zoom factors are relative to.\n",
    "    order: spline interpolation order passed to `zoom` (3 is zoom's default);\n",
    "        use 0 for label maps so that no new label values are interpolated.\n",
    "    \"\"\"\n",
    "    assert len(shape) == 3, \"shape must have exactly 3 dimensions\"\n",
    "    factors = (\n",
    "        shape[0]/orig_shape[0],\n",
    "        shape[1]/orig_shape[1],\n",
    "        shape[2]/orig_shape[2]\n",
    "    )\n",
    "\n",
    "    # Resize to the given shape\n",
    "    return zoom(img, factors, mode=mode, order=order)\n",
    "\n",
    "\n",
    "def preprocess_label(img, out_shape=None, mode='nearest'):\n",
    "    \"\"\"\n",
    "    Separates out the 3 labels from the segmentation provided, namely:\n",
    "    GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))\n",
    "    and the necrotic and non-enhancing tumor core (NCR/NET — label 1)\n",
    "\n",
    "    Returns a uint8 array stacking [ncr, ed, et]; each mask is resized to\n",
    "    `out_shape` when it is given.\n",
    "    \"\"\"\n",
    "    ncr = img == 1  # Necrotic and Non-Enhancing Tumor (NCR/NET)\n",
    "    ed = img == 2   # Peritumoral Edema (ED)\n",
    "    et = img == 4   # GD-enhancing Tumor (ET)\n",
    "\n",
    "    if out_shape is not None:\n",
    "        # order=0: masks must stay binary, so do not spline-interpolate them.\n",
    "        ncr = resize(ncr, out_shape, mode=mode, order=0)\n",
    "        ed = resize(ed, out_shape, mode=mode, order=0)\n",
    "        et = resize(et, out_shape, mode=mode, order=0)\n",
    "    return np.array([ncr, ed, et], dtype=np.uint8)\n",
    "\n",
    "\n",
    "def load_normalized_modality(path, shape=(160, 224, 224)):\n",
    "    \"\"\"Read one modality file, resize it to `shape` and normalize it to zero mean and unit std.\"\"\"\n",
    "    volume = resize(read_img_sitk(path), shape)\n",
    "    return (volume - volume.mean()) / volume.std()\n",
    "\n",
    "\n",
    "# Location of the BraTS 2019 HGG training cases; set BRATS_HGG_DIR to override.\n",
    "hgg_path = os.environ.get(\n",
    "    'BRATS_HGG_DIR',\n",
    "    '/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG')\n",
    "case_id = 'BraTS19_TMC_30014_1'\n",
    "case_dir = os.path.join(hgg_path, case_id)\n",
    "\n",
    "# Channel order matches `img_type`: FLAIR, T1, T1CE, T2.\n",
    "np_image = np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
    "for channel, modality in enumerate(['flair', 't1', 't1ce', 't2']):\n",
    "    np_image[channel] = load_normalized_modality(\n",
    "        os.path.join(case_dir, '%s_%s.nii.gz' % (case_id, modality)))\n",
    "\n",
    "print(np_image.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e5b3c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "vu.show_n_images(np_image[:,100,:,:], titles=img_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19117da5",
   "metadata": {},
   "outputs": [],
   "source": [
    "seg_path = os.path.join(case_dir, case_id + '_seg.nii.gz')\n",
    "\n",
    "# order=0 (nearest neighbour) keeps the original label values; spline\n",
    "# interpolation would create values that are not valid labels.\n",
    "np_lbl = resize(read_img_sitk(seg_path), [160, 224, 224], mode='nearest', order=0).astype(int)\n",
    "print(np_lbl.shape)\n",
    "\n",
    "print(np_image.shape)\n",
    "\n",
    "img1 = vu.show_label_on_image4(np_image[1,100,:,:], np_lbl[100])\n",
    "img2 = vu.show_label_on_image(np_image[1,100,:,:], np_lbl[100])\n",
    "vu.show_n_images([img1,img2,np_image[0,100]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4d8e1f0",
   "metadata": {},
   "source": [
    "## Prediction post-processing helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "facdea15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(img, threshold=0.5):\n",
    "    \"\"\"Return a new array that is 1 where `img` > `threshold` and 0 elsewhere.\"\"\"\n",
    "    return np.where(img > threshold, 1, 0)\n",
    "\n",
    "\n",
    "def prediction_from_probabily_3D(img):\n",
    "    \"\"\"Threshold the probabilities in `img` at 0.5 and collapse them with `lbl_from_cat`.\"\"\"\n",
    "    int_image = get_pred(img)\n",
    "    return lbl_from_cat(int_image)\n",
    "\n",
    "\n",
    "def get_prediction_for_batch(pred_batch, threshold=0.5):\n",
    "    \"\"\"Convert a batch of per-slice probabilities into (N, 224, 224) integer label maps.\n",
    "\n",
    "    Each pred_batch[j] is thresholded with `threshold` and collapsed with\n",
    "    `lbl_from_cat`; slices with any positive value are reported via print.\n",
    "    \"\"\"\n",
    "    # `int` replaces the `np.int` alias, which was removed in NumPy 1.24.\n",
    "    out_batch = np.zeros((pred_batch.shape[0], 224, 224), dtype=int)\n",
    "\n",
    "    for j in range(pred_batch.shape[0]):\n",
    "        pred = get_pred(pred_batch[j], threshold)\n",
    "        if pred.sum() > 0:\n",
    "            print(j, np.unique(pred, return_counts=True))\n",
    "        out_batch[j] = lbl_from_cat(pred)\n",
    "    return out_batch\n",
    "\n",
    "\n",
    "def get_label_from_pred_batch(labels_batch):\n",
    "    \"\"\"Encode channels 0, 1 and 2 of each slice as label values 1, 2 and 4 (summed) in a uint8 array.\"\"\"\n",
    "    batch = np.zeros((labels_batch.shape[0], 224, 224), np.uint8)\n",
    "\n",
    "    for j in range(labels_batch.shape[0]):\n",
    "        batch[j] = (get_pred(labels_batch[j,:,:,0])\n",
    "                    + get_pred(labels_batch[j,:,:,1])*2\n",
    "                    + get_pred(labels_batch[j,:,:,2])*4)\n",
    "\n",
    "    return batch\n",
    "\n",
    "\n",
    "def predict_3D_img_prob(np_file):\n",
    "    \"\"\"Load a saved numpy image, normalize it and return the model's prediction.\n",
    "\n",
    "    NOTE(review): this calls `model.predict`, which looks like the Keras-style\n",
    "    API of the reference project; the MIGraphX program in this notebook is\n",
    "    executed with `model.run`. This helper is not called anywhere in the\n",
    "    notebook and presumably needs porting before use - TODO confirm.\n",
    "    \"\"\"\n",
    "    np_img = np.load(np_file)\n",
    "\n",
    "    # Normalize image (normalize_3D_image is defined in a later cell)\n",
    "    for_pred_img = normalize_3D_image(np_img)\n",
    "\n",
    "    mdl_pred_img = model.predict(for_pred_img)\n",
    "\n",
    "    return mdl_pred_img"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a6b3c5",
   "metadata": {},
   "source": [
    "## Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7fe7ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remember the MIGraphX model inputs\n",
    "print(model.get_parameter_names())\n",
    "print(model.get_parameter_shapes())\n",
    "\n",
    "# (4, 160, 224, 224) -> (4, 224, 224, 160). A separate name keeps this cell\n",
    "# safe to re-run (rebinding np_image would transpose it again each time).\n",
    "np_image_model = np_image.transpose((0,2,3,1))\n",
    "\n",
    "print(np_image_model.shape)\n",
    "print(np_image_model.strides)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfc47b53",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_3D_image(img):\n",
    "    \"\"\"Scale every img[z, :, :, k] plane (k in 0..3) by its maximum when that maximum is positive.\n",
    "\n",
    "    The array is modified in place and also returned.\n",
    "    \"\"\"\n",
    "    for z in range(img.shape[0]):\n",
    "        for k in range(4):\n",
    "            if img[z,:,:,k].max() > 0:\n",
    "                img[z,:,:,k] /= img[z,:,:,k].max()\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f990cb50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channels-last view used for display: (4, 160, 224, 224) -> (160, 224, 224, 4)\n",
    "print(np_image.shape)\n",
    "np_image_channels_last = np_image.transpose((1,2,3,0))\n",
    "print(np_image_channels_last.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c3736d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the leading batch axis\n",
    "np_image_batched = np.expand_dims(np_image_model, 0)\n",
    "print(np_image_batched.shape)\n",
    "print(np_image_batched.strides)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1aac6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The transposed view is not C-contiguous; copy it into a contiguous float32\n",
    "# buffer so that the strides match the (1, 4, 224, 224, 160) shape.\n",
    "input_im = np.ascontiguousarray(np_image_batched, dtype=np.float32)\n",
    "print(input_im.strides)\n",
    "print(input_im.shape)\n",
    "\n",
    "result = model.run({\n",
    "    \"input\": input_im\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2b5f7a8",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5848b63d",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = np.array(result[0])\n",
    "print(output.shape)\n",
    "output = output[0]  # drop the batch axis\n",
    "print(output.shape)\n",
    "output = output.transpose((3,1,2,0))  # last axis first, first axis last\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab77f7e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = prediction_from_probabily_3D(output)\n",
    "print(np_image_channels_last.shape)\n",
    "print(np_lbl.shape)\n",
    "print(out.shape)\n",
    "print(np.unique(out))\n",
    "ind = [100]\n",
    "for i in ind:\n",
    "    show_label(output[i])\n",
    "    show_label(get_pred(output[i]))\n",
    "    show_pred_im_label(np_image_channels_last[i], np_lbl[i], out[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2862d81",
   "metadata": {},
   "source": [
    "The possible prediction discrepancy is due to the imperfect resizing of the 3D input image: the BraTS dataset has 3D images of size 155x240x240, whereas the ONNX model used here requires 160x224x224. This example is representative of how to use MIGraphX for such an application. All data processing should otherwise follow and match the model requirements."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}