Commit f94d77fc authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 03929873 6403d482
...@@ -147,6 +147,7 @@ jobs: ...@@ -147,6 +147,7 @@ jobs:
os: os:
- ubuntu-16.04 - ubuntu-16.04
- ubuntu-18.04 - ubuntu-18.04
- ubuntu-20.04
configuration: configuration:
- debug - debug
- release - release
...@@ -208,7 +209,7 @@ jobs: ...@@ -208,7 +209,7 @@ jobs:
rbuild build -d cget -s gh -t check \ rbuild build -d cget -s gh -t check \
-DCMAKE_BUILD_TYPE=${{matrix.configuration}} \ -DCMAKE_BUILD_TYPE=${{matrix.configuration}} \
-DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \ -DMIGRAPHX_ENABLE_PYTHON=${{matrix.configuration == 'release' && 'On' || 'Off'}} \
-DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer ${{matrix.os != 'ubuntu-16.04' && '-fsanitize-address-use-after-scope' || ''}} -fsanitize=undefined,address -fno-sanitize-recover=undefined,address" \ -DCMAKE_CXX_FLAGS_DEBUG="-g1 -Os -fdebug-prefix-map=$PWD=. -fdebug-types-section -fno-omit-frame-pointer -fsanitize=undefined -fno-sanitize-recover=undefined" \
-DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \ -DCMAKE_CXX_FLAGS_CODECOV="-g1 -Og -fdebug-prefix-map=$PWD=. -fdebug-types-section -fprofile-arcs -ftest-coverage -fno-omit-frame-pointer" \
-DCMAKE_EXE_LINKER_FLAGS='-fuse-ld=gold' \ -DCMAKE_EXE_LINKER_FLAGS='-fuse-ld=gold' \
-DCMAKE_SHARED_LINKER_FLAGS='-fuse-ld=gold' -DCMAKE_SHARED_LINKER_FLAGS='-fuse-ld=gold'
......
...@@ -15,11 +15,13 @@ def rocmtestnode(Map conf) { ...@@ -15,11 +15,13 @@ def rocmtestnode(Map conf) {
def cmd = """ def cmd = """
env env
ulimit -c unlimited ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
rm -rf build rm -rf build
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} .. CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ${flags} ..
CTEST_PARALLEL_LEVEL=32 make -j\$(nproc) generate all doc package check VERBOSE=1 make -j\$(nproc) generate all doc package check VERBOSE=1
""" """
echo cmd echo cmd
sh cmd sh cmd
...@@ -73,6 +75,8 @@ def rocmnodename(name) { ...@@ -73,6 +75,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && fiji"; node_name = "${rocmtest_name} && fiji";
} else if(name == "vega") { } else if(name == "vega") {
node_name = "${rocmtest_name} && vega"; node_name = "${rocmtest_name} && vega";
} else if(name == "nogpu") {
return rocmtest_name;
} }
return node_name return node_name
} }
...@@ -100,6 +104,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build -> ...@@ -100,6 +104,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'") cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
} }
}, clang_asan: rocmnode('nogpu') { cmake_build ->
stage('Clang ASAN') {
def sanitizers = "undefined,address"
def debug_flags = "-g -O2 -fno-omit-frame-pointer -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=Off -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
}
} }
def onnxnode(name, body) { def onnxnode(name, body) {
......
...@@ -10,8 +10,9 @@ This directory contains examples of common use cases for MIGraphX. ...@@ -10,8 +10,9 @@ This directory contains examples of common use cases for MIGraphX.
- [Exporting Frozen Graphs in TF2](./export_frozen_graph_tf2) - [Exporting Frozen Graphs in TF2](./export_frozen_graph_tf2)
- [MIGraphX Docker Container](./migraphx_docker) - [MIGraphX Docker Container](./migraphx_docker)
- [MIGraphX Driver](./migraphx_driver) - [MIGraphX Driver](./migraphx_driver)
- [Python Resnet50 Inference](./python_api_inference) - [Python Resnet50](./python_api_inference)
- [Python BERT SQuAD Inference](./python_bert_squad_example) - [Python BERT-SQuAD](./python_bert_squad_example)
- [Python Super Resolution](./python_super_resolution) - [Python Super Resolution](./python_super_resolution)
- [Python NFNet Inference](./python_nfnet_inference) - [Python NFNet](./python_nfnet_inference)
- [Python U-Net](./python_unet)
- [Python 3D-UNet](./python_3dunet)
...@@ -148,13 +148,6 @@ ...@@ -148,13 +148,6 @@
"\n", "\n",
"print(process.stdout)" "print(process.stdout)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {
...@@ -172,8 +165,7 @@ ...@@ -172,8 +165,7 @@
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3"
"version": "3.7.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
{
"cells": [
{
"cell_type": "markdown",
"id": "fee8cfa5",
"metadata": {},
"source": [
"# 3D-UNet Example with MIGraphX\n",
"References:<br>\n",
"https://github.com/naomifridman/Unet_Brain_tumor_segmentation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb22bcc4",
"metadata": {},
"outputs": [],
"source": [
"import migraphx\n",
"from PIL import Image\n",
"import numpy as np\n",
"import os\n",
"import SimpleITK as sitk"
]
},
{
"cell_type": "markdown",
"id": "cb973c63",
"metadata": {},
"source": [
"## Fetch U-NET ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1928662c",
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://zenodo.org/record/3928973/files/224_224_160.onnx"
]
},
{
"cell_type": "markdown",
"id": "1a64a616",
"metadata": {},
"source": [
"## Load ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53928a98",
"metadata": {},
"outputs": [],
"source": [
"model = migraphx.parse_onnx(\"224_224_160.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e8587f",
"metadata": {},
"outputs": [],
"source": [
"model.compile(migraphx.get_target(\"gpu\"))"
]
},
{
"cell_type": "markdown",
"id": "2f6014a4",
"metadata": {},
"source": [
"## Print model parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e73728c",
"metadata": {},
"outputs": [],
"source": [
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())\n",
"print(model.get_output_shapes())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4cac52e",
"metadata": {},
"outputs": [],
"source": [
"img_type=['FLAIR', 'T1','T1CE', 'T2']\n",
"label_type_shrt = ['background', 'necrotic',\n",
" 'edema', 'enhancing']\n",
"label_type = ['background', 'necrotic and non-enhancing tumor', 'edema', 'enhancing tumor']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b65f9297",
"metadata": {},
"outputs": [],
"source": [
"red_multiplier = [1, 0.2, 0.2]\n",
"green_multiplier = [0.35,0.75,0.25]\n",
"blue_multiplier = [0,0.5,1.]#[0,0.25,0.9]\n",
"yellow_multiplier = [1,1,0.25]\n",
"brown_miltiplier = [40./255, 26./255, 13./255]\n",
"my_colors=[blue_multiplier, yellow_multiplier, brown_miltiplier]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e175ac5",
"metadata": {},
"outputs": [],
"source": [
"from importlib import reload # Python 3.4+ only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "530e4f97",
"metadata": {},
"outputs": [],
"source": [
"import visualization_utils as vu\n",
"from visualization_utils import show_label_on_image4\n",
"reload(vu)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "865c46a2",
"metadata": {},
"outputs": [],
"source": [
"def show_img_label(img, lbl, modality = 0):\n",
" \n",
" if (len(lbl.shape)> 2):\n",
" lbl[0,0,3]=1 # for uniqe colors in plot\n",
" lbl = lbl_from_cat(lbl)\n",
" vu.show_n_images([img[:,:,modality],lbl, show_label_on_image4(img[:,:,modality],lbl)],\n",
" titles = [img_type[modality], 'Label', 'Label on '+ img_type[modality]]);\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e926482",
"metadata": {},
"outputs": [],
"source": [
"def read_img_sitk(img):\n",
" inputImage = sitk.ReadImage( img )\n",
" inputImage = sitk.Cast( inputImage, sitk.sitkFloat32 )\n",
" image = sitk.GetArrayFromImage(inputImage)\n",
" return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b620138",
"metadata": {},
"outputs": [],
"source": [
"# ima files are of the form\n",
"# BraTS19_TCIA04_192_1_flair.nii.gz \n",
"# BraTS19_TCIA04_192_1_t1.nii.gz \n",
"# BraTS19_TCIA04_192_1_t2.nii.gz\n",
"# BraTS19_TCIA04_192_1_seg.nii.gz \n",
"# BraTS19_TCIA04_192_1_t1ce.nii.gz\n",
"\n",
"def read_image_into_numpy(dirpath):\n",
" \n",
" img_id = os.path.basename(dirpath)\n",
" np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
" \n",
" ## Flair\n",
" flair_img = os.path.join(dirpath, img_id+'_flair.nii.gz')\n",
" if (not os.path.isfile(flair_img)):\n",
" print(flair_img,' not found aborting')\n",
" return None\n",
" np_image[0] = read_img_sitk(flair_img)\n",
" \n",
" ## T1\n",
" t1_nb4_img = os.path.join(dirpath, img_id+'_t1_nb4.nii.gz')\n",
" if (not os.path.isfile(t1_nb4_img)):\n",
" #print(t1_nb4_img,' not found')\n",
" t1_img = os.path.join(dirpath, img_id+'_t1.nii.gz')\n",
" if (not os.path.isfile(t1_img)):\n",
" print(t1_img,' not found aborting')\n",
" return None\n",
" np_image[1] = read_img_sitk(t1_img)\n",
" else:\n",
" np_image[1] = read_img_sitk(t1_nb4_img) \n",
" \n",
" ## T1CE\n",
" t1ce_nb4_img = os.path.join(dirpath, img_id+'_t1ce_nb4.nii.gz')\n",
" if (not os.path.isfile(t1ce_nb4_img)):\n",
" #print(t1ce_nb4_img,' not found')\n",
" t1ce_img = os.path.join(dirpath, img_id+'_t1ce.nii.gz')\n",
" if (not os.path.isfile(t1ce_img)):\n",
" print(t1ce_img,' not found aborting')\n",
" return None\n",
" np_image[2] = read_img_sitk(t1ce_img)\n",
" else:\n",
" np_image[2] = read_img_sitk(t1ce_nb4_img) \n",
" \n",
" \n",
" ## T2\n",
" t2_img = os.path.join(dirpath, img_id+'_t2.nii.gz')\n",
" if (not os.path.isfile(t2_img)):\n",
" print(t2_img,' not found aborting')\n",
" return None\n",
" np_image[3] = read_img_sitk(t2_img)\n",
"\n",
" return np_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fb66f17",
"metadata": {},
"outputs": [],
"source": [
"def read_label_into_numpy(dirpath):\n",
" \n",
" img_id = os.path.basename(dirpath)\n",
" np_image=np.zeros((160, 224, 224), dtype=np.int)\n",
" \n",
" ## label\n",
" label_img = os.path.join(dirpath, img_id+'_seg.nii.gz')\n",
" if (not os.path.isfile(label_img)):\n",
" print(label_img,' not found aborting')\n",
" return None\n",
" np_image = read_img_sitk(label_img).astype(int)\n",
"\n",
" return np_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "558d47b9",
"metadata": {},
"outputs": [],
"source": [
"def bbox2_3D(img):\n",
"\n",
" r = np.any(img, axis=(1, 2))\n",
" c = np.any(img, axis=(0, 2))\n",
" z = np.any(img, axis=(0, 1))\n",
"\n",
" rmin, rmax = np.where(r)[0][[0, -1]]\n",
" cmin, cmax = np.where(c)[0][[0, -1]]\n",
" zmin, zmax = np.where(z)[0][[0, -1]]\n",
"\n",
" return [rmin, rmax, cmin, cmax, zmin, zmax]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1405e186",
"metadata": {},
"outputs": [],
"source": [
"def lbl_from_cat(cat_lbl):\n",
" \n",
" lbl=0\n",
" if (len(cat_lbl.shape)==3):\n",
" for i in range(1,4):\n",
" lbl = lbl + cat_lbl[:,:,i]*i\n",
" elif (len(cat_lbl.shape)==4):\n",
" for i in range(1,4):\n",
" lbl = lbl + cat_lbl[:,:,:,i]*i\n",
" else:\n",
" print('Error in lbl_from_cat', cat_lbl.shape)\n",
" return None\n",
" return lbl"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24eb472f",
"metadata": {},
"outputs": [],
"source": [
"def show_label(lbl):\n",
" vu.show_n_images([lbl[:,:,k] for k in range(4)]+[lbl_from_cat(lbl)],\n",
" titles = label_type_shrt + ['Label'])\n",
"\n",
"def show_pred_im_label(im, lb, pred):\n",
" \n",
" vu.show_n_images([im[:,:,1], lb[:,:], \n",
" show_label_on_image4(im[:,:,1], lb[:,:]),\n",
" show_label_on_image4(im[:,:,1], pred[:,:])],\n",
" titles=['Flair', 'Label', 'Label on T1', 'Prediction on Flair'])\n",
"\n",
"def show_pred_im(im, pred):\n",
" \n",
" vu.show_n_images([im[:,:,1], \n",
" im[:,:,0],pred,\n",
" show_label_on_image4(im[:,:,1], pred[:,:])],\n",
" titles=['Flair','T1', 'Pred', 'Prediction on Flair'])"
]
},
{
"cell_type": "markdown",
"id": "d15f788b",
"metadata": {},
"source": [
"Multiple image inputs:\n",
"- Native (T1)\n",
"- Post-contrast T1-weighted (T1Gd)\n",
"- T2-weighted (T2)\n",
"- T2 Fluid Attenuated Inversion Recovery (T2-FLAIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a7aad87",
"metadata": {},
"outputs": [],
"source": [
"# Resize input images\n",
"from scipy.ndimage import zoom\n",
"\n",
"def resize(img, shape, mode='constant', orig_shape=(155, 240, 240)):\n",
" \"\"\"\n",
" Wrapper for scipy.ndimage.zoom suited for MRI images.\n",
" \"\"\"\n",
" assert len(shape) == 3, \"Can not have more than 3 dimensions\"\n",
" factors = (\n",
" shape[0]/orig_shape[0],\n",
" shape[1]/orig_shape[1], \n",
" shape[2]/orig_shape[2]\n",
" )\n",
" \n",
" # Resize to the given shape\n",
" return zoom(img, factors, mode=mode)\n",
"\n",
"def preprocess_label(img, out_shape=None, mode='nearest'):\n",
" \"\"\"\n",
" Separates out the 3 labels from the segmentation provided, namely:\n",
" GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))\n",
" and the necrotic and non-enhancing tumor core (NCR/NET — label 1)\n",
" \"\"\"\n",
" ncr = img == 1 # Necrotic and Non-Enhancing Tumor (NCR/NET)\n",
" \n",
" ed = img == 2 # Peritumoral Edema (ED)\n",
" et = img == 4 # GD-enhancing Tumor (ET)\n",
" \n",
" if out_shape is not None:\n",
" ncr = resize(ncr, out_shape, mode=mode)\n",
" ed = resize(ed, out_shape, mode=mode)\n",
" et = resize(et, out_shape, mode=mode)\n",
" return np.array([ncr, ed, et], dtype=np.uint8)\n",
"\n",
"hgg_path = \"/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG\"\n",
"np_image=np.zeros((4, 160, 224, 224), dtype=np.float32)\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_flair.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[0] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[1] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t1ce.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[2] = (tmp - mean) / std\n",
"\n",
"tmp = read_img_sitk('%s/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_t2.nii.gz'%hgg_path)\n",
"tmp = resize(tmp, [160,224,224])\n",
"mean = tmp.mean()\n",
"std = tmp.std()\n",
"np_image[3] = (tmp - mean) / std\n",
"\n",
"print(np_image.shape)\n",
"np_image_tmp = np_image.copy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7e5b3c6",
"metadata": {},
"outputs": [],
"source": [
"vu.show_n_images(np_image[:,100,:,:], titles=img_type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19117da5",
"metadata": {},
"outputs": [],
"source": [
"np_lbl=np.zeros((160, 224, 224), dtype=np.int)\n",
"tmp = read_img_sitk('/code/AMDMIGraphX/bratsdata/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_TMC_30014_1/BraTS19_TMC_30014_1_seg.nii.gz').astype(int)\n",
"tmp = resize(tmp, [160,224,224])\n",
"print(tmp.shape)\n",
"np_lbl = tmp.astype(int)\n",
"print(np_lbl.shape)\n",
"\n",
"print(np_image.shape)\n",
"\n",
"img1 = vu.show_label_on_image4(np_image[1,100,:,:], np_lbl[100])\n",
"img2 = vu.show_label_on_image(np_image[1,100,:,:], np_lbl[100])\n",
"vu.show_n_images([img1,img2,np_image[0,100]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "facdea15",
"metadata": {},
"outputs": [],
"source": [
"def get_pred(img, threshold=0.5):\n",
" out_img=img.copy()\n",
" out_img=np.where(out_img>threshold, 1,0)\n",
" return out_img\n",
"\n",
"def prediction_from_probabily_3D(img):\n",
" \n",
" int_image = get_pred(img)\n",
" return lbl_from_cat(int_image)\n",
"\n",
"def get_prediction_for_batch(pred_batch, threshold=0.5):\n",
" \n",
" out_batch = np.zeros((pred_batch.shape[0], 224, 224),dtype=np.int)\n",
" \n",
" for j in range(pred_batch.shape[0]):\n",
" pred = get_prediction(pred_batch[j])\n",
" if (pred.sum()>0):\n",
" print(j, np.unique(pred , return_counts=True))\n",
" out_batch[j] = lbl_from_cat(get_prediction(pred_batch[j]))\n",
" return out_batch\n",
"\n",
"def get_label_from_pred_batch(labels_batch):\n",
" \n",
" batch = np.zeros((labels_batch.shape[0], 224, 224), np.uint8)\n",
" \n",
" for j in range(labels_batch.shape[0]):\n",
" batch[j]=get_pred(labels_batch[j,:,:,0])+\\\n",
" get_pred(labels_batch[j,:,:,1])*2+\\\n",
" get_pred(labels_batch[j,:,:,2])*4\n",
"\n",
" return batch\n",
"\n",
"def predict_3D_img_prob(np_file):\n",
" \n",
" np_img = np.load(np_file)\n",
" for_pred_img = np.zeros((160, 224, 224, 4), np.float32)\n",
"\n",
" # Normalize image\n",
" for_pred_img = normalize_3D_image(np_img)\n",
"\n",
" mdl_pred_img = model.predict(for_pred_img)\n",
"\n",
" #pred_label = prediction_from_probabily_3D(mdl_pred_img)\n",
"\n",
" return mdl_pred_img\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f7fe7ee",
"metadata": {},
"outputs": [],
"source": [
"#Remember the MIGraphX model inputs\n",
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())\n",
"\n",
"np_image = np_image.transpose((0,2,3,1))\n",
"\n",
"print(np_image.shape)\n",
"print(np_image.strides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfc47b53",
"metadata": {},
"outputs": [],
"source": [
"def normalize_3D_image(img):\n",
" for z in range(img.shape[0]):\n",
" for k in range(4):\n",
" if (img[z,:,:,k].max()>0):\n",
" img[z,:,:,k] /= img[z,:,:,k].max()\n",
" return img"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f990cb50",
"metadata": {},
"outputs": [],
"source": [
"print(np_image_tmp.shape)\n",
"np_image_tmp = np_image_tmp.transpose((1,2,3,0))\n",
"print(np_image_tmp.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24c3736d",
"metadata": {},
"outputs": [],
"source": [
"np_image = np.expand_dims(np_image, 0)\n",
"print(np_image.shape)\n",
"print(np_image.strides)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1aac6285",
"metadata": {},
"outputs": [],
"source": [
"input_im = np.zeros((1,4,224,224,160),dtype='float32')\n",
"np.lib.stride_tricks.as_strided(input_im, shape=np_image.shape, strides=input_im.strides)[:] = np_image #getting correct stride\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"\n",
"#input_im = normalize_3D_image(input_im)\n",
"\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"\n",
"result = model.run({\n",
" \"input\": input_im\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5848b63d",
"metadata": {},
"outputs": [],
"source": [
"output = np.array(result[0])\n",
"print(output.shape)\n",
"output = output[0]\n",
"print(output.shape)\n",
"output = output.transpose((3,1,2,0))\n",
"print(output.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab77f7e9",
"metadata": {},
"outputs": [],
"source": [
"out = prediction_from_probabily_3D(output)\n",
"print(np_image_tmp.shape)\n",
"print(np_lbl.shape)\n",
"print(out.shape)\n",
"print(np.unique(out))\n",
"ind=[100]\n",
"for i in ind:\n",
" show_label(output[i])\n",
" show_label(get_pred(output[i]))\n",
" show_pred_im_label(np_image_tmp[i], np_lbl[i], out[i])"
]
},
{
"cell_type": "markdown",
"id": "d2862d81",
"metadata": {},
"source": [
    "The possible prediction discrepancy is due to the imperfect resizing of the 3D input image: the BRATS dataset provides 3D images of size 155x240x240, while the ONNX model used here requires 160x224x224. This example demonstrates how to use MIGraphX for such an application; all other data processing should follow and match the model's requirements."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# 3D-UNet Inference with AMD MIGraphX
This example applies image segmentation to 3D images using AMD MIGraphX on a given AMD GPU.
## How to:
1) User will need to have access to the BRATS dataset. Please follow https://www.med.upenn.edu/cbica/brats2019/data.html for how to get access to the dataset.
2) Follow the provided notebook `3dunet_inference.ipynb`.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.pylab as pylab
import numpy as np
# Global matplotlib style for this module: x-large fonts and a 6x5-inch
# default figure so side-by-side MRI slices and their titles stay legible
# in the accompanying notebook.
params = {
    'legend.fontsize': 'x-large',
    'figure.figsize': (6, 5),
    'axes.labelsize': 'x-large',
    'axes.titlesize': 'x-large',
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large'
}
# Apply the style globally for every figure created after this import.
pylab.rcParams.update(params)
#-----------------------------------------------------------
def show_n_images(imgs, titles=None, enlarge=20, cmap='jet'):
    """Display a sequence of 2-D images in a single horizontal row.

    Parameters:
        imgs: sequence of 2-D arrays to display.
        titles: optional sequence of per-image titles, same length as imgs.
        enlarge: figure width in inches; the height is twice this value.
        cmap: matplotlib colormap name applied to all images.
    """
    plt.set_cmap(cmap)
    num_images = len(imgs)
    grid = gridspec.GridSpec(1, num_images)
    figure = plt.figure()
    # One wide figure holding every image side by side.
    figure.set_size_inches(enlarge, 2 * enlarge)
    for idx, img in enumerate(imgs):
        axis = figure.add_subplot(grid[idx])
        axis.imshow(img, interpolation='none')
        if titles is not None:
            axis.set_title(titles[idx])
        # Flip the y-axis so the array origin is rendered at the top.
        axis.set_ylim(axis.get_ylim()[::-1])
    plt.show()
#--------------------------------------------------------------
from skimage import color, img_as_float
from skimage.exposure import adjust_gamma
# Creates an image of original brain with segmentation overlay
# Creates an image of original brain with segmentation overlay
def show_label_on_image(test_img, test_lbl):
    """Paint segmentation classes as solid colors on top of a brain slice.

    Parameters:
        test_img: 2-D grayscale image (any positive scale; normalized here).
        test_lbl: 2-D integer label map of the same spatial size, with
            classes encoded as 1, 2, 3, 4.

    Returns:
        An RGB float image where labeled pixels are replaced by the class
        color and unlabeled pixels show the gamma-brightened anatomy.
    """
    label_im = test_lbl
    # Normalize to [0, 1], then brighten with gamma < 1 so the anatomy
    # stays visible next to the saturated label colors.
    gray_img = img_as_float(test_img / test_img.max())
    # print(color.gray2rgb(gray_img))
    image = adjust_gamma(np.abs(color.gray2rgb(gray_img)), 0.45)
    green_multiplier = [0.35, 0.75, 0.25]
    blue_multiplier = [0, 0.5, 1.]  #[0,0.25,0.9]
    yellow_multiplier = [1, 1, 0.25]
    brown_multiplier = [40. / 255, 26. / 255, 13. / 255]
    # Recolor each class with a vectorized boolean-mask assignment.
    # This is exactly equivalent to the previous per-index Python loops
    # over np.argwhere results, but runs in C instead of the interpreter
    # (and matches the style used by show_label_on_image4 below).
    image[label_im == 1] = blue_multiplier
    image[label_im == 2] = yellow_multiplier
    image[label_im == 3] = brown_multiplier  #blue_multiplier
    image[label_im == 4] = green_multiplier  #yellow_multiplier
    return image
#-------------------------------------------------------------------------------------
def show_label_on_image4(test_img, label_im):
    """Blend a segmentation label map onto a brain slice via HSV mixing.

    Unlike show_label_on_image, the overlay is translucent: the hue and
    (alpha-scaled) saturation come from the per-class color mask while the
    brightness of the underlying anatomy is preserved.

    Parameters:
        test_img: 2-D grayscale image (normalized to [0, 1] internally).
        label_im: 2-D integer label map with classes 1, 2, 3, 4.

    Returns:
        An RGB float image with the colored overlay applied.
    """
    blend_alpha = 0.8
    gray = img_as_float(test_img / test_img.max())
    rows, cols = gray.shape
    # Build an RGB mask holding the color assigned to each class.
    class_colors = {
        1: [0, 0.25, 0.9],                     # blue
        2: [1, 1, 0.25],                       # yellow
        3: [40. / 255, 26. / 255, 13. / 255],  # brown
        4: [0.35, 0.75, 0.25],                 # green
    }
    color_mask = np.zeros((rows, cols, 3))
    for class_value, rgb in class_colors.items():
        color_mask[label_im == class_value] = rgb
    # Grayscale slice as RGB, then convert both images to HSV space.
    base_hsv = color.rgb2hsv(np.dstack((gray, gray, gray)))
    mask_hsv = color.rgb2hsv(color_mask)
    # Take hue from the mask and a dimmed saturation, keeping the
    # original image's value channel (brightness) untouched.
    base_hsv[..., 0] = mask_hsv[..., 0]
    base_hsv[..., 1] = mask_hsv[..., 1] * blend_alpha
    return color.hsv2rgb(base_hsv)
#------------------------------------------------------------------------------
...@@ -339,8 +339,7 @@ ...@@ -339,8 +339,7 @@
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3"
"version": "3.6.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -273,8 +273,7 @@ ...@@ -273,8 +273,7 @@
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3"
"version": "3.6.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
import numpy as np import numpy as np
import json import json
import time
import os.path import os.path
from os import path
import sys
import tokenizers import tokenizers
from run_onnx_squad import * import collections
from run_onnx_squad import read_squad_examples, write_predictions, convert_examples_to_features
import migraphx import migraphx
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
####################################### #######################################
input_file = 'inputs_amd.json' input_file = 'inputs_amd.json'
with open(input_file) as json_file: with open(input_file) as json_file:
......
...@@ -29,7 +29,6 @@ python onnx_squad.py --model $SQUAD_MODEL/squad.onnx \ ...@@ -29,7 +29,6 @@ python onnx_squad.py --model $SQUAD_MODEL/squad.onnx \
import argparse import argparse
import collections import collections
import json import json
import logging
import math import math
import os import os
import sys import sys
...@@ -145,8 +144,6 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -145,8 +144,6 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_to_orig_index.append(i) tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token) all_doc_tokens.append(sub_token)
tok_start_position = None
tok_end_position = None
# The -3 accounts for [CLS], [SEP] and [SEP] # The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
...@@ -567,7 +564,7 @@ def main(): ...@@ -567,7 +564,7 @@ def main():
sess_options = onnxrt.SessionOptions() sess_options = onnxrt.SessionOptions()
sess_options.session_log_verbosity_level = args.log sess_options.session_log_verbosity_level = args.log
tokenizer = BertWordPieceTokenizer(vocab_file) tokenizer = BertWordPieceTokenizer(args.vocab_file)
eval_examples = read_squad_examples(input_file=args.predict_file) eval_examples = read_squad_examples(input_file=args.predict_file)
input_ids, input_mask, segment_ids, extra_data = \ input_ids, input_mask, segment_ids, extra_data = \
......
...@@ -213,7 +213,7 @@ ...@@ -213,7 +213,7 @@
"# Run the model\n", "# Run the model\n",
"start = time.time()\n", "start = time.time()\n",
"results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n",
"print(f\"Time inference took: {100*(time.time() - start):.2f}ms\")\n", "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n",
"# Extract the index of the top prediction\n", "# Extract the index of the top prediction\n",
"res_npa = np.array(results[0])\n", "res_npa = np.array(results[0])\n",
"print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
...@@ -228,7 +228,7 @@ ...@@ -228,7 +228,7 @@
"# Run the model again, first one would take long\n", "# Run the model again, first one would take long\n",
"start = time.time()\n", "start = time.time()\n",
"results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n", "results = model.run({'inputs':data_input}) # Your first inference would take longer than the following ones.\n",
"print(f\"Time inference took: {100*(time.time() - start):.2f}ms\")\n", "print(f\"Time inference took: {1000*(time.time() - start):.2f}ms\")\n",
"# Extract the index of the top prediction\n", "# Extract the index of the top prediction\n",
"res_npa = np.array(results[0])\n", "res_npa = np.array(results[0])\n",
"print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")" "print(f\"\\nResult: {labels[np.argmax(res_npa)]}\")"
...@@ -250,8 +250,7 @@ ...@@ -250,8 +250,7 @@
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3"
"version": "3.6.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -210,8 +210,7 @@ ...@@ -210,8 +210,7 @@
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3"
"version": "3.6.9"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
# U-Net Image Segmentation Inference with AMD MIGraphX
This example shows how to run image segmentation with a U-Net ONNX model, using the AMD MIGraphX graph optimization engine for fast inference.
## How-to
Please use the provided notebook, `unet_inference.ipynb`.
## Model Details
The ONNX model used in this example can be downloaded [here](https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx).
\ No newline at end of file
numpy
matplotlib
\ No newline at end of file
{
"cells": [
{
"cell_type": "markdown",
"id": "cd7a3990",
"metadata": {},
"source": [
"## Import MIGraphX Python Library"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3930d7b8",
"metadata": {},
"outputs": [],
"source": [
"import migraphx\n",
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"id": "b350c333",
"metadata": {},
"source": [
"## Fetch U-NET ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02a7b7de",
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://www.dropbox.com/s/3ntkhyk30x05uuv/unet_13_256.onnx"
]
},
{
"cell_type": "markdown",
"id": "a6cfe6e9",
"metadata": {},
"source": [
"## Load ONNX Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e05a13dc",
"metadata": {},
"outputs": [],
"source": [
"model = migraphx.parse_onnx(\"unet_13_256.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52c67023",
"metadata": {},
"outputs": [],
"source": [
"model.compile(migraphx.get_target(\"gpu\"))"
]
},
{
"cell_type": "markdown",
"id": "80edb6f1",
"metadata": {},
"source": [
"## Print model parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd5c3269",
"metadata": {},
"outputs": [],
"source": [
"print(model.get_parameter_names())\n",
"print(model.get_parameter_shapes())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47f956c7",
"metadata": {},
"outputs": [],
"source": [
"def preprocess(pil_img, newW, newH):\n",
" w, h = pil_img.size\n",
" assert newW > 0 and newH > 0, 'Scale is too small'\n",
" pil_img = pil_img.resize((newW, newH))\n",
"\n",
" img_nd = np.array(pil_img)\n",
"\n",
" if len(img_nd.shape) == 2:\n",
" img_nd = np.expand_dims(img_nd, axis=2)\n",
"\n",
" # HWC to CHW\n",
" img_print = pil_img\n",
" img_trans = img_nd.transpose((2, 0, 1))\n",
" if img_trans.max() > 1:\n",
" img_trans = img_trans / 255\n",
" \n",
" img_trans = np.expand_dims(img_trans, 0)\n",
"\n",
" return img_trans, img_print\n",
"\n",
"def plot_img_and_mask(img, mask):\n",
" classes = mask.shape[0] if len(mask.shape) > 3 else 1\n",
" print(classes)\n",
" fig, ax = plt.subplots(1, classes + 1)\n",
" ax[0].set_title('Input image')\n",
" ax[0].imshow(img)\n",
" if classes > 1:\n",
" for i in range(classes):\n",
" ax[i+1].set_title(f'Output mask (class {i+1})')\n",
" ax[i+1].imshow(mask[:, :, i])\n",
" else:\n",
" ax[1].set_title(f'Output mask')\n",
" ax[1].imshow(mask[0,0])\n",
" plt.xticks([]), plt.yticks([])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "389ddc4d",
"metadata": {},
"outputs": [],
"source": [
"img = Image.open(\"./car1.jpeg\")\n",
"img, imPrint = preprocess(img, 256, 256)\n",
"input_im = np.zeros((1,3,256,256),dtype='float32') \n",
"np.lib.stride_tricks.as_strided(input_im, shape=img.shape, strides=input_im.strides)[:] = img #getting correct stride\n",
"print(input_im.strides)\n",
"print(input_im.shape)\n",
"imPrint.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9de6f2a7",
"metadata": {},
"outputs": [],
"source": [
"mask = model.run({'inputs':input_im}) # Your first inference would take longer than the following ones.\n",
"output_mask = np.array(mask[0])\n",
"print(output_mask.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acbd68e3",
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(x):\n",
" return 1 / (1 + np.exp(-x))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58e3062c",
"metadata": {},
"outputs": [],
"source": [
"probs = sigmoid(output_mask)\n",
"full_mask = probs > 0.996\n",
"plot_img_and_mask(imPrint, full_mask)"
]
},
{
"cell_type": "markdown",
"id": "6126df0b",
"metadata": {},
"source": [
    "<b>NOTE:</b> The model weights used here were trained on car images with plain backgrounds, so the imperfect result on the \"real-world\" image shown above is expected. To get a better result, fine-tune the model on a dataset of real-world examples."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
...@@ -9,57 +9,60 @@ add_library(migraphx ...@@ -9,57 +9,60 @@ add_library(migraphx
analyze_streams.cpp analyze_streams.cpp
argument.cpp argument.cpp
auto_contiguous.cpp auto_contiguous.cpp
eliminate_common_subexpression.cpp common.cpp
decompose.cpp
propagate_constant.cpp
compile_src.cpp compile_src.cpp
convert_to_json.cpp
cpp_generator.cpp cpp_generator.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
decompose.cpp
dom_info.cpp dom_info.cpp
dynamic_loader.cpp dynamic_loader.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_common_subexpression.cpp
eliminate_concat.cpp eliminate_concat.cpp
eliminate_contiguous.cpp
eliminate_data_type.cpp eliminate_data_type.cpp
eliminate_identity.cpp eliminate_identity.cpp
eliminate_pad.cpp eliminate_pad.cpp
insert_pad.cpp
file_buffer.cpp
rewrite_batchnorm.cpp
rewrite_rnn.cpp
rewrite_pooling.cpp
env.cpp env.cpp
file_buffer.cpp
generate.cpp generate.cpp
inline_module.cpp inline_module.cpp
insert_pad.cpp
instruction.cpp instruction.cpp
json.cpp
load_save.cpp load_save.cpp
make_op.cpp make_op.cpp
module.cpp
msgpack.cpp msgpack.cpp
normalize_attributes.cpp
normalize_ops.cpp
operation.cpp operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
pass_manager.cpp
permutation.cpp permutation.cpp
preallocate_param.cpp
process.cpp process.cpp
program.cpp program.cpp
module.cpp propagate_constant.cpp
quantization.cpp quantization.cpp
reduce_dims.cpp reduce_dims.cpp
register_op.cpp
register_target.cpp
remap.cpp remap.cpp
shape.cpp rewrite_batchnorm.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp
schedule.cpp schedule.cpp
serialize.cpp serialize.cpp
pass_manager.cpp shape.cpp
register_op.cpp
register_target.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
tmp_dir.cpp tmp_dir.cpp
value.cpp value.cpp
verify_args.cpp verify_args.cpp
json.cpp
convert_to_json.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
normalize_attributes.cpp
normalize_ops.cpp
) )
configure_file(version.h.in include/migraphx/version.h) configure_file(version.h.in include/migraphx/version.h)
rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx ${MIGRAPHX_SO_VERSION})
...@@ -92,6 +95,7 @@ register_migraphx_ops( ...@@ -92,6 +95,7 @@ register_migraphx_ops(
cosh cosh
cos cos
deconvolution deconvolution
dequantizelinear
div div
dot dot
elu elu
...@@ -130,6 +134,7 @@ register_migraphx_ops( ...@@ -130,6 +134,7 @@ register_migraphx_ops(
prelu prelu
quant_convolution quant_convolution
quant_dot quant_dot
quantizelinear
recip recip
reduce_max reduce_max
reduce_mean reduce_mean
...@@ -146,6 +151,7 @@ register_migraphx_ops( ...@@ -146,6 +151,7 @@ register_migraphx_ops(
round round
rsqrt rsqrt
scalar scalar
scatter
sigmoid sigmoid
sign sign
sinh sinh
......
...@@ -124,6 +124,14 @@ argument::data_t argument::data_t::from_args(const std::vector<argument>& args) ...@@ -124,6 +124,14 @@ argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
return result; return result;
} }
// Deep copy: allocate a fresh buffer of the same shape and clone the raw bytes,
// so the returned argument shares no storage with this one.
argument argument::copy() const
{
    const auto s = this->get_shape();
    argument result{s};
    auto* src_begin = this->data();
    auto* src_end   = src_begin + s.bytes();
    std::copy(src_begin, src_end, result.data());
    return result;
}
argument argument::share() const { return {m_shape, m_data.share()}; } argument argument::share() const { return {m_shape, m_data.share()}; }
std::vector<argument> argument::get_sub_objects() const std::vector<argument> argument::get_sub_objects() const
......
#include <migraphx/common.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    // Fast path: identical shapes broadcast to themselves.
    if(s0 == s1)
        return s0;
    // Normalize so s0 is the shorter-rank shape; remember the swap so any
    // error message still reports the shapes in the caller's original order.
    const bool swapped = s0.size() > s1.size();
    if(swapped)
        s0.swap(s1);
    // Start from the longer shape; leading dimensions missing from the shorter
    // shape keep the longer shape's extents (numpy-style broadcasting).
    std::vector<std::size_t> out_lens(s1);
    auto offset = s1.size() - s0.size();
    std::transform(
        s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
            // Two extents are compatible when they are equal or either is 1.
            if(a != b and a != 1 and b != 1)
            {
                const auto& lhs = swapped ? s1 : s0;
                const auto& rhs = swapped ? s0 : s1;
                MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(lhs) +
                               "} and {" + to_string_range(rhs) + "} mismatch!");
            }
            return std::max(a, b);
        });
    return out_lens;
}
std::vector<std::size_t> compute_common_lens(const std::vector<shape>& shapes)
{
assert(not shapes.empty());
return transform_accumulate(shapes.begin() + 1,
shapes.end(),
shapes.front().lens(),
&compute_broadcasted_lens,
[](auto s) { return s.lens(); });
}
// Promote two element types to the type both implicitly convert to,
// following std::common_type on the underlying C++ types.
shape::type_t compute_common_type(shape::type_t t1, shape::type_t t2)
{
    // Identical types need no promotion.
    if(t1 == t2)
        return t1;
    shape::type_t promoted;
    shape::visit(t1, [&](auto as1) {
        shape::visit(t2, [&](auto as2) {
            // Workaround broken warning on gcc 5
            (void)as1;
            (void)as2;
            using common = std::common_type_t<decltype(as1()), decltype(as2())>;
            promoted     = shape::get_type<common>{};
        });
    });
    return promoted;
}
// Reduce all shapes' element types pairwise into one promoted type.
shape::type_t compute_common_types(const std::vector<shape>& shapes)
{
    assert(not shapes.empty());
    auto result = shapes.front().type();
    std::for_each(shapes.begin() + 1, shapes.end(), [&](const auto& s) {
        result = compute_common_type(result, s.type());
    });
    return result;
}
// Build the shape every input broadcasts/converts to: common element type
// plus common broadcasted lengths. No inputs yields an empty shape.
shape common_shape(const std::vector<shape>& shapes)
{
    if(shapes.empty())
        return {};
    auto t = compute_common_types(shapes);
    auto l = compute_common_lens(shapes);
    return {t, l};
}
// Insert op before ins, first coercing every input to the common shape:
// a multibroadcast when the lengths differ, then a convert when the
// element type differs.
instruction_ref insert_common_op(module& m,
                                 instruction_ref ins,
                                 const operation& op,
                                 std::vector<instruction_ref> inputs)
{
    const auto target = common_shape(to_shapes(inputs));
    for(auto& input : inputs)
    {
        if(input->get_shape().lens() != target.lens())
        {
            input = m.insert_instruction(
                ins, make_op("multibroadcast", {{"output_lens", target.lens()}}), input);
        }
        if(input->get_shape().type() != target.type())
        {
            input = m.insert_instruction(
                ins, make_op("convert", {{"target_type", target.type()}}), input);
        }
    }
    return m.insert_instruction(ins, op, inputs);
}
// Append op at the end of the module: appending is just inserting
// before the end iterator.
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs)
{
    auto end_pos = m.end();
    return insert_common_op(m, end_pos, op, std::move(inputs));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment