"...composable_kernel_rocm.git" did not exist on "7d700bc0a403b2a235a617c4bdd221c3a5bcd877"
Unverified Commit fcb49e71 authored by Cagri Eryilmaz, committed by GitHub

[Example] 3D-Unet (#883)



* unet3d notebook, visualization

* inference notebook for unet3d, sample input inference

* unet3d performance migraphx notebook

* separating unet3d from unet

* remove unet from unet3d branch

* readme updates

* rename file

* sample inference with brats dataset for unet3d notebook

* required visualization file

* readme update

* remove perf script, move to branch

* remove unused functions from vis

* renaming

* renaming more

* py format

* main readme update

* notebook update

* update readme for data access

* cleanup notebook

* Update README.md

duplicate pointers in readme

* Update examples/python_3dunet/README.md
Co-authored-by: kahmed10 <15948690+kahmed10@users.noreply.github.com>

* Update examples/python_3dunet/README.md
Co-authored-by: kahmed10 <15948690+kahmed10@users.noreply.github.com>

* cleanup notebook

* label typos

* cleanup

* vis util import cleanup

* path changes + npsave remove
Co-authored-by: kahmed10 <15948690+kahmed10@users.noreply.github.com>
Co-authored-by: Chris Austen <causten@users.noreply.github.com>
parent 9d71a5e6
@@ -15,3 +15,4 @@ This directory contains examples of common use cases for MIGraphX.
- [Python Super Resolution](./python_super_resolution)
- [Python NFNet](./python_nfnet_inference)
- [Python U-Net](./python_unet)
- [Python 3D-UNet](./python_3dunet)
{
"cells": [
{
"cell_type": "markdown",
"id": "fee8cfa5",
"metadata": {},
"source": [
"# 3D-UNet Example with MIGraphX\n",
"References:<br>\n",
"https://github.com/naomifridman/Unet_Brain_tumor_segmentation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb22bcc4",
"metadata": {},
"outputs": [],
"source": [
"import migraphx\n",
"from PIL import Image\n",
"import numpy as np\n",
"import os\n",
"import SimpleITK as sitk"
]
},
{
"cell_type": "markdown",
"id": "cb973c63",
"metadata": {},
"source": [
"## Fetch U-NET ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1928662c",
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx"
]
},
{
"cell_type": "markdown",
"id": "1a64a616",
"metadata": {},
"source": [
"## Load ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53928a98",
"metadata": {},
"outputs": [],
"source": [
"model = migraphx.parse_onnx(\"224_224_160.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e8587f",
"metadata": {},
"outputs": [],
"source": [
"model.compile(migraphx.get_target(\"gpu\"))"
]
},
{
"cell_type": "markdown",
"id": "2f6014a4",
"metadata": {},
"source": [
"## Print model parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e73728c",
"metadata": {},
"outputs": [],
"source": [
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())\n",
"print(model.get_output_shapes())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4cac52e",
"metadata": {},
"outputs": [],
"source": [
"img_type=['FLAIR', 'T1','T1CE', 'T2']\n",
"label_type_shrt = ['background', 'necrotic',\n",
" 'edema', 'enhancing']\n",
"label_type = ['background', 'necrotic and non-enhancing tumor', 'edema', 'enhancing tumor']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b65f9297",
"metadata": {},
"outputs": [],
"source": [
"red_multiplier = [1, 0.2, 0.2]\n",
"green_multiplier = [0.35,0.75,0.25]\n",
"blue_multiplier = [0,0.5,1.]#[0,0.25,0.9]\n",
"yellow_multiplier = [1,1,0.25]\n",
"brown_miltiplier = [40./255, 26./255, 13./255]\n",
"my_colors=[blue_multiplier, yellow_multiplier, brown_miltiplier]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e175ac5",
"metadata": {},
"outputs": [],
"source": [
"from importlib import reload # Python 3.4+ only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "530e4f97",
"metadata": {},
"outputs": [],
"source": [
"import visualization_utils as vu\n",
"from visualization_utils import show_label_on_image4\n",
"reload(vu)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "865c46a2",
"metadata": {},
"outputs": [],
"source": [
"def show_img_label(img, lbl, modality = 0):\n",
" \n",
" if (len(lbl.shape)> 2):\n",
" lbl[0,0,3]=1 # for uniqe colors in plot\n",
" lbl = lbl_from_cat(lbl)\n",
" vu.show_n_images([img[:,:,modality],lbl, show_label_on_image4(img[:,:,modality],lbl)],\n",
" titles = [img_type[modality], 'Label', 'Label on '+ img_type[modality]]);\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e926482",
"metadata": {},
"outputs": [],
"source": [
"def read_img_sitk(img):\n",
" inputImage = sitk.ReadImage( img )\n",
" inputImage = sitk.Cast( inputImage, sitk.sitkFloat32 )\n",
" image = sitk.GetArrayFromImage(inputImage)\n",
" return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b620138",
"metadata": {},
"outputs": [],
"source": [
"# ima files are of the form\n",
"# BraTS19_TCIA04_192_1_flair.nii.gz \n",
"# BraTS19_TCIA04_192_1_t1.nii.gz \n",
"# BraTS19_TCIA04_192_1_t2.nii.gz\n",
"# BraTS19_TCIA04_192_1_seg.nii.gz \n",
"# BraTS19_TCIA04_192_1_t1ce.nii.gz\n",
"\n",
"def read_image_into_numpy(dirpath):\n",
" \n",
" img_id = os.path.basename(dirpath)\n",
" np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
" \n",
" ## Flair\n",
" flair_img = os.path.join(dirpath, img_id+'_flair.nii.gz')\n",
" if (not os.path.isfile(flair_img)):\n",
" print(flair_img,' not found aborting')\n",
" return None\n",
" np_image[0] = read_img_sitk(flair_img)\n",
" \n",
" ## T1\n",
" t1_nb4_img = os.path.join(dirpath, img_id+'_t1_nb4.nii.gz')\n",
" if (not os.path.isfile(t1_nb4_img)):\n",
" #print(t1_nb4_img,' not found')\n",
" t1_img = os.path.join(dirpath, img_id+'_t1.nii.gz')\n",
" if (not os.path.isfile(t1_img)):\n",
" print(t1_img,' not found aborting')\n",
" return None\n",
" np_image[1] = read_img_sitk(t1_img)\n",
" else:\n",
" np_image[1] = read_img_sitk(t1_nb4_img) \n",
" \n",
" ## T1CE\n",
" t1ce_nb4_img = os.path.join(dirpath, img_id+'_t1ce_nb4.nii.gz')\n",
" if (not os.path.isfile(t1ce_nb4_img)):\n",
" #print(t1ce_nb4_img,' not found')\n",
" t1ce_img = os.path.join(dirpath, img_id+'_t1ce.nii.gz')\n",
" if (not os.path.isfile(t1ce_img)):\n",
" print(t1ce_img,' not found aborting')\n",
" return None\n",
" np_image[2] = read_img_sitk(t1ce_img)\n",
" else:\n",
" np_image[2] = read_img_sitk(t1ce_nb4_img) \n",
" \n",
" \n",
" ## T2\n",
" t2_img = os.path.join(dirpath, img_id+'_t2.nii.gz')\n",
" if (not os.path.isfile(t2_img)):\n",
" print(t2_img,' not found aborting')\n",
" return None\n",
" np_image[3] = read_img_sitk(t2_img)\n",
"\n",
" return np_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fb66f17",
"metadata": {},
"outputs": [],
"source": [
"def read_label_into_numpy(dirpath):\n",
" \n",
" img_id = os.path.basename(dirpath)\n",
" np_image=np.zeros((160, 224, 224), dtype=np.int)\n",
" \n",
" ## label\n",
" label_img = os.path.join(dirpath, img_id+'_seg.nii.gz')\n",
" if (not os.path.isfile(label_img)):\n",
" print(label_img,' not found aborting')\n",
" return None\n",
" np_image = read_img_sitk(label_img).astype(int)\n",
"\n",
" return np_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "558d47b9",
"metadata": {},
"outputs": [],
"source": [
"def bbox2_3D(img):\n",
"\n",
" r = np.any(img, axis=(1, 2))\n",
" c = np.any(img, axis=(0, 2))\n",
" z = np.any(img, axis=(0, 1))\n",
"\n",
" rmin, rmax = np.where(r)[0][[0, -1]]\n",
" cmin, cmax = np.where(c)[0][[0, -1]]\n",
" zmin, zmax = np.where(z)[0][[0, -1]]\n",
"\n",
" return [rmin, rmax, cmin, cmax, zmin, zmax]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1405e186",
"metadata": {},
"outputs": [],
"source": [
"def lbl_from_cat(cat_lbl):\n",
" \n",
" lbl=0\n",
" if (len(cat_lbl.shape)==3):\n",
" for i in range(1,4):\n",
" lbl = lbl + cat_lbl[:,:,i]*i\n",
" elif (len(cat_lbl.shape)==4):\n",
" for i in range(1,4):\n",
" lbl = lbl + cat_lbl[:,:,:,i]*i\n",
" else:\n",
" print('Error in lbl_from_cat', cat_lbl.shape)\n",
" return None\n",
" return lbl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24eb472f",
"metadata": {},
"outputs": [],
"source": [
"def show_label(lbl):\n",
" vu.show_n_images([lbl[:,:,k] for k in range(4)]+[lbl_from_cat(lbl)],\n",
" titles = label_type_shrt + ['Label'])\n",
"\n",
"def show_pred_im_label(im, lb, pred):\n",
" \n",
" vu.show_n_images([im[:,:,1], lb[:,:], \n",
" show_label_on_image4(im[:,:,1], lb[:,:]),\n",
" show_label_on_image4(im[:,:,1], pred[:,:])],\n",
" titles=['Flair', 'Label', 'Label on T1', 'Prediction on Flair'])\n",
"\n",
"def show_pred_im(im, pred):\n",
" \n",
" vu.show_n_images([im[:,:,1], \n",
" im[:,:,0],pred,\n",
" show_label_on_image4(im[:,:,1], pred[:,:])],\n",
" titles=['Flair','T1', 'Pred', 'Prediction on Flair'])"
]
},
{
"cell_type": "markdown",
"id": "d15f788b",
"metadata": {},
"source": [
"Multiple image inputs:\n",
"- Native (T1)\n",
"- Post-contrast T1-weighted (T1Gd)\n",
"- T2-weighted (T2)\n",
"- T2 Fluid Attenuated Inversion Recovery (T2-FLAIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a7aad87",
"metadata": {},
"outputs": [],
"source": [
"# Resize input images\n",
"from scipy.ndimage import zoom\n",
"\n",
"def resize(img, shape, mode='constant', orig_shape=(155, 240, 240)):\n",
" \"\"\"\n",
" Wrapper for scipy.ndimage.zoom suited for MRI images.\n",
" \"\"\"\n",
" assert len(shape) == 3, \"Can not have more than 3 dimensions\"\n",
" factors = (\n",
" shape[0]/orig_shape[0],\n",
" shape[1]/orig_shape[1], \n",
" shape[2]/orig_shape[2]\n",
" )\n",
" \n",
" # Resize to the given shape\n",
" return zoom(img, factors, mode=mode)\n",
"\n",
"def preprocess_label(img, out_shape=None, mode='nearest'):\n",
" \"\"\"\n",
" Separates out the 3 labels from the segmentation provided, namely:\n",
" GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))\n",
" and the necrotic and non-enhancing tumor core (NCR/NET — label 1)\n",
" \"\"\"\n",
" ncr = img == 1 # Necrotic and Non-Enhancing Tumor (NCR/NET)\n",
" \n",
" ed = img == 2 # Peritumoral Edema (ED)\n",
" et = img == 4 # GD-enhancing Tumor (ET)\n",
" \n",
" if out_shape is not None:\n",
" ncr = resize(ncr, out_shape, mode=mode)\n",
" ed = resize(ed, out_shape, mode=mode)\n",
" et = resize(et, out_shape, mode=mode)\n",
" return np.array([ncr, ed, et], dtype=np.uint8)\n",
"\n",
"hgg_path = \"/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG\"\n",
"np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_flair.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[0] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[1] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1ce.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[2] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t2.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[3] = (tmp - mean) / std\n",
"\n",
"print(np_image.shape)\n",
"np_image_tmp = np_image.copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7e5b3c6",
"metadata": {},
"outputs": [],
"source": [
"vu.show_n_images(np_image[:,100,:,:], titles=img_type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19117da5",
"metadata": {},
"outputs": [],
"source": [
"np_lbl=np.zeros((160, 224, 224), dtype=np.int)\n",
"tmp = read_img_sitk('/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_seg.nii.gz').astype(int)\n",
"tmp = resize(tmp, [160,224,224])\n",
"print(tmp.shape)\n",
"np_lbl = tmp.astype(int)\n",
"print(np_lbl.shape)\n",
"\n",
"print(np_image.shape)\n",
"\n",
"img1 = vu.show_label_on_image4(np_image[1,100,:,:], np_lbl[100])\n",
"img2 = vu.show_label_on_image(np_image[1,100,:,:], np_lbl[100])\n",
"vu.show_n_images([img1,img2,np_image[0,100]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "facdea15",
"metadata": {},
"outputs": [],
"source": [
"def get_pred(img, threshold=0.5):\n",
" out_img=img.copy()\n",
" out_img=np.where(out_img>threshold, 1,0)\n",
" return out_img\n",
"\n",
"def prediction_from_probabily_3D(img):\n",
" \n",
" int_image = get_pred(img)\n",
" return lbl_from_cat(int_image)\n",
"\n",
"def get_prediction_for_batch(pred_batch, threshold=0.5):\n",
" \n",
" out_batch = np.zeros((pred_batch.shape[0], 224, 224),dtype=np.int)\n",
" \n",
" for j in range(pred_batch.shape[0]):\n",
" pred = get_prediction(pred_batch[j])\n",
" if (pred.sum()>0):\n",
" print(j, np.unique(pred , return_counts=True))\n",
" out_batch[j] = lbl_from_cat(get_prediction(pred_batch[j]))\n",
" return out_batch\n",
"\n",
"def get_label_from_pred_batch(labels_batch):\n",
" \n",
" batch = np.zeros((labels_batch.shape[0], 224, 224), np.uint8)\n",
" \n",
" for j in range(labels_batch.shape[0]):\n",
" batch[j]=get_pred(labels_batch[j,:,:,0])+\\\n",
" get_pred(labels_batch[j,:,:,1])*2+\\\n",
" get_pred(labels_batch[j,:,:,2])*4\n",
"\n",
" return batch\n",
"\n",
"def predict_3D_img_prob(np_file):\n",
" \n",
" np_img = np.load(np_file)\n",
" for_pred_img = np.zeros((160, 224, 224, 4), np.float32)\n",
"\n",
" # Normalize image\n",
" for_pred_img = normalize_3D_image(np_img)\n",
"\n",
" mdl_pred_img = model.predict(for_pred_img)\n",
"\n",
" #pred_label = prediction_from_probabily_3D(mdl_pred_img)\n",
"\n",
" return mdl_pred_img\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f7fe7ee",
"metadata": {},
"outputs": [],
"source": [
"#Remember the MIGraphX model inputs\n",
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())\n",
"\n",
"np_image = np_image.transpose((0,2,3,1))\n",
"\n",
"print(np_image.shape)\n",
"print(np_image.strides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfc47b53",
"metadata": {},
"outputs": [],
"source": [
"def normalize_3D_image(img):\n",
" for z in range(img.shape[0]):\n",
" for k in range(4):\n",
" if (img[z,:,:,k].max()>0):\n",
" img[z,:,:,k] /= img[z,:,:,k].max()\n",
" return img"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f990cb50",
"metadata": {},
"outputs": [],
"source": [
"print(np_image_tmp.shape)\n",
"np_image_tmp = np_image_tmp.transpose((1,2,3,0))\n",
"print(np_image_tmp.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24c3736d",
"metadata": {},
"outputs": [],
"source": [
"np_image = np.expand_dims(np_image, 0)\n",
"print(np_image.shape)\n",
"print(np_image.strides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1aac6285",
"metadata": {},
"outputs": [],
"source": [
"input_im = np.zeros((1,4,224,224,160),dtype='float32')\n",
"np.lib.stride_tricks.as_strided(input_im, shape=np_image.shape, strides=input_im.strides)[:] = np_image #getting correct stride\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"\n",
"#input_im = normalize_3D_image(input_im)\n",
"\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"\n",
"result = model.run({\n",
" \"input\": input_im\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5848b63d",
"metadata": {},
"outputs": [],
"source": [
"output = np.array(result[0])\n",
"print(output.shape)\n",
"output = output[0]\n",
"print(output.shape)\n",
"output = output.transpose((3,1,2,0))\n",
"print(output.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab77f7e9",
"metadata": {},
"outputs": [],
"source": [
"out = prediction_from_probabily_3D(output)\n",
"print(np_image_tmp.shape)\n",
"print(np_lbl.shape)\n",
"print(out.shape)\n",
"print(np.unique(out))\n",
"ind=[100]\n",
"for i in ind:\n",
" show_label(output[i])\n",
" show_label(get_pred(output[i]))\n",
" show_pred_im_label(np_image_tmp[i], np_lbl[i], out[i])"
]
},
{
"cell_type": "markdown",
"id": "d2862d81",
"metadata": {},
"source": [
"The possible prediction discrepancy is due to the not-perfect resizing 3D input image, as BRATS dataset has 3D images of size 160x240x240, meanwhile the ONNX model utilized here requires 155x224x224. This example is representative for how to utilize MIGraphX for such an application. All data processing should follow and match the model requirements otherwise. "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# 3D-UNet Inference with AMD MIGraphX
This example performs image segmentation on 3D brain MRI volumes using AMD MIGraphX on an AMD GPU.
## How to:
1) You will need access to the BraTS dataset. See https://www.med.upenn.edu/cbica/brats2019/data.html for instructions on requesting access.
2) Follow the provided notebook `3dunet_inference.ipynb`.
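
For a quick orientation before opening the notebook, the snippet below is a minimal sketch of the MIGraphX calls the notebook uses (parse the ONNX file, compile for the GPU target, run inference). The model URL and the `input` parameter name are taken from the notebook; the zero-filled array is only a placeholder for the preprocessed BraTS volume (four modalities resized to 160x224x224 and normalized) that the notebook builds.

```python
# Minimal sketch of the inference flow in 3dunet_inference.ipynb.
# The ONNX model is fetched in the notebook with:
#   wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx
import numpy as np
import migraphx

model = migraphx.parse_onnx("224_224_160.onnx")
model.compile(migraphx.get_target("gpu"))
print(model.get_parameter_shapes())  # the notebook feeds "input" with shape (1, 4, 224, 224, 160)

# Placeholder input; in the notebook this is built from the four BraTS modalities
# (FLAIR, T1, T1CE, T2), each resized to 160x224x224 and normalized.
input_im = np.zeros((1, 4, 224, 224, 160), dtype=np.float32)

result = model.run({"input": input_im})
output = np.array(result[0])  # per-class probability volume
print(output.shape)
```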
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.pylab as pylab
import numpy as np
params = {
'legend.fontsize': 'x-large',
'figure.figsize': (6, 5),
'axes.labelsize': 'x-large',
'axes.titlesize': 'x-large',
'xtick.labelsize': 'x-large',
'ytick.labelsize': 'x-large'
}
pylab.rcParams.update(params)
#-----------------------------------------------------------
def show_n_images(imgs, titles=None, enlarge=20, cmap='jet'):
plt.set_cmap(cmap)
n = len(imgs)
gs1 = gridspec.GridSpec(1, n)
fig1 = plt.figure()
# create a figure with the default size
fig1.set_size_inches(enlarge, 2 * enlarge)
for i in range(n):
ax1 = fig1.add_subplot(gs1[i])
ax1.imshow(imgs[i], interpolation='none')
if (titles is not None):
ax1.set_title(titles[i])
ax1.set_ylim(ax1.get_ylim()[::-1])
plt.show()
#--------------------------------------------------------------
from skimage import io, color, img_as_float
from skimage.exposure import adjust_gamma
# Creates an image of original brain with segmentation overlay
def show_label_on_image(test_img, test_lbl):
modes = {'flair': 0, 't1': 1, 't1c': 2, 't2': 3}
label_im = test_lbl
ones = np.argwhere(label_im == 1)
twos = np.argwhere(label_im == 2)
threes = np.argwhere(label_im == 3)
fours = np.argwhere(label_im == 4)
gray_img = img_as_float(test_img / test_img.max())
# adjust gamma of image
# print(color.gray2rgb(gray_img))
image = adjust_gamma(np.abs(color.gray2rgb(gray_img)), 0.45)
#sliced_image = image.copy()
red_multiplier = [1, 0.2, 0.2]
green_multiplier = [0.35, 0.75, 0.25]
blue_multiplier = [0, 0.5, 1.] #[0,0.25,0.9]
yellow_multiplier = [1, 1, 0.25]
brown_multiplier = [40. / 255, 26. / 255, 13. / 255]
# change colors of segmented classes
for i in range(len(ones)):
image[ones[i][0]][ones[i][1]] = blue_multiplier #red_multiplier
for i in range(len(twos)):
image[twos[i][0]][twos[i][1]] = yellow_multiplier
for i in range(len(threes)):
image[threes[i][0]][threes[i][1]] = brown_multiplier #blue_multiplier
for i in range(len(fours)):
image[fours[i][0]][fours[i][1]] = green_multiplier #yellow_multiplier
return image
#-------------------------------------------------------------------------------------
def show_label_on_image4(test_img, label_im):
alpha = 0.8
img = img_as_float(test_img / test_img.max())
rows, cols = img.shape
# Construct a colour image to superimpose
color_mask = np.zeros((rows, cols, 3))
red_multiplier = [1, 0.2, 0.2]
green_multiplier = [0.35, 0.75, 0.25]
blue_multiplier = [0, 0.25, 0.9]
yellow_multiplier = [1, 1, 0.25]
brown_multiplier = [40. / 255, 26. / 255, 13. / 255]
# map each segmentation label value to a distinct color
color_mask[label_im == 1] = blue_multiplier
color_mask[label_im == 2] = yellow_multiplier
color_mask[label_im == 3] = brown_multiplier
color_mask[label_im == 4] = green_multiplier
# Construct RGB version of grey-level image
img_color = np.dstack((img, img, img))
# Convert the input image and color mask to Hue Saturation Value (HSV)
# colorspace
img_hsv = color.rgb2hsv(img_color)
color_mask_hsv = color.rgb2hsv(color_mask)
# Replace the hue and saturation of the original image
# with that of the color mask
img_hsv[..., 0] = color_mask_hsv[..., 0]
img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha
img_masked = color.hsv2rgb(img_hsv)
return img_masked
#------------------------------------------------------------------------------